initial commit
This commit is contained in:
@@ -0,0 +1,95 @@
|
||||
"""**Graphs** provide a natural language interface to graph databases."""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.graphs.arangodb_graph import (
|
||||
ArangoGraph,
|
||||
)
|
||||
from langchain_community.graphs.falkordb_graph import (
|
||||
FalkorDBGraph,
|
||||
)
|
||||
from langchain_community.graphs.gremlin_graph import (
|
||||
GremlinGraph,
|
||||
)
|
||||
from langchain_community.graphs.hugegraph import (
|
||||
HugeGraph,
|
||||
)
|
||||
from langchain_community.graphs.kuzu_graph import (
|
||||
KuzuGraph,
|
||||
)
|
||||
from langchain_community.graphs.memgraph_graph import (
|
||||
MemgraphGraph,
|
||||
)
|
||||
from langchain_community.graphs.nebula_graph import (
|
||||
NebulaGraph,
|
||||
)
|
||||
from langchain_community.graphs.neo4j_graph import (
|
||||
Neo4jGraph,
|
||||
)
|
||||
from langchain_community.graphs.neptune_graph import (
|
||||
BaseNeptuneGraph,
|
||||
NeptuneAnalyticsGraph,
|
||||
NeptuneGraph,
|
||||
)
|
||||
from langchain_community.graphs.neptune_rdf_graph import (
|
||||
NeptuneRdfGraph,
|
||||
)
|
||||
from langchain_community.graphs.networkx_graph import (
|
||||
NetworkxEntityGraph,
|
||||
)
|
||||
from langchain_community.graphs.ontotext_graphdb_graph import (
|
||||
OntotextGraphDBGraph,
|
||||
)
|
||||
from langchain_community.graphs.rdf_graph import (
|
||||
RdfGraph,
|
||||
)
|
||||
from langchain_community.graphs.tigergraph_graph import (
|
||||
TigerGraph,
|
||||
)
|
||||
|
||||
# Names exported by ``from langchain_community.graphs import *``.
# Every entry here also appears as a key of ``_module_lookup`` below.
__all__ = [
    "ArangoGraph",
    "FalkorDBGraph",
    "GremlinGraph",
    "HugeGraph",
    "KuzuGraph",
    "BaseNeptuneGraph",
    "MemgraphGraph",
    "NebulaGraph",
    "Neo4jGraph",
    "NeptuneGraph",
    "NeptuneRdfGraph",
    "NeptuneAnalyticsGraph",
    "NetworkxEntityGraph",
    "OntotextGraphDBGraph",
    "RdfGraph",
    "TigerGraph",
]

# Maps each public class name to the dotted path of the module defining it.
# The module-level ``__getattr__`` consults this table and imports the target
# module only when the attribute is first requested.
_module_lookup = {
    "ArangoGraph": "langchain_community.graphs.arangodb_graph",
    "FalkorDBGraph": "langchain_community.graphs.falkordb_graph",
    "GremlinGraph": "langchain_community.graphs.gremlin_graph",
    "HugeGraph": "langchain_community.graphs.hugegraph",
    "KuzuGraph": "langchain_community.graphs.kuzu_graph",
    "MemgraphGraph": "langchain_community.graphs.memgraph_graph",
    "NebulaGraph": "langchain_community.graphs.nebula_graph",
    "Neo4jGraph": "langchain_community.graphs.neo4j_graph",
    "BaseNeptuneGraph": "langchain_community.graphs.neptune_graph",
    "NeptuneAnalyticsGraph": "langchain_community.graphs.neptune_graph",
    "NeptuneGraph": "langchain_community.graphs.neptune_graph",
    "NeptuneRdfGraph": "langchain_community.graphs.neptune_rdf_graph",
    "NetworkxEntityGraph": "langchain_community.graphs.networkx_graph",
    "OntotextGraphDBGraph": "langchain_community.graphs.ontotext_graphdb_graph",
    "RdfGraph": "langchain_community.graphs.rdf_graph",
    "TigerGraph": "langchain_community.graphs.tigergraph_graph",
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a public graph class on first access (module-level hook).

    Args:
        name: attribute requested on this package.

    Returns:
        The attribute called ``name`` taken from the module that
        ``_module_lookup`` associates with it.

    Raises:
        AttributeError: if ``name`` has no entry in ``_module_lookup``.
    """
    module_path = _module_lookup.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    # Import is deferred until here so optional backends are only loaded
    # when the corresponding class is actually requested.
    return getattr(importlib.import_module(module_path), name)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
765
venv/Lib/site-packages/langchain_community/graphs/age_graph.py
Normal file
765
venv/Lib/site-packages/langchain_community/graphs/age_graph.py
Normal file
@@ -0,0 +1,765 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from hashlib import md5
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Pattern, Tuple, Union
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument
|
||||
from langchain_community.graphs.graph_store import GraphStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import psycopg2.extras
|
||||
|
||||
|
||||
class AGEQueryException(Exception):
    """Exception for the AGE queries.

    Args:
        exception: either a plain message string, or a dict carrying a
            ``"message"`` entry and a ``"details"`` (or ``"detail"``) entry.
            Missing entries default to the string ``"unknown"``.
    """

    def __init__(self, exception: Union[str, Dict]) -> None:
        if isinstance(exception, dict):
            self.message = exception.get("message", "unknown")
            # The raise sites in this module pass the key "detail" (singular);
            # accept both spellings so the underlying error text is not lost.
            self.details = exception.get(
                "details", exception.get("detail", "unknown")
            )
        else:
            self.message = exception
            self.details = "unknown"

    def get_message(self) -> str:
        """Return the short description of the failure."""
        return self.message

    def get_details(self) -> Any:
        """Return the detail payload, or ``"unknown"`` when none was given."""
        return self.details
|
||||
|
||||
|
||||
class AGEGraph(GraphStore):
    """
    Apache AGE wrapper for graph operations.

    Args:
        graph_name (str): the name of the graph to connect to or create
        conf (Dict[str, Any]): the pgsql connection config passed directly
            to psycopg2.connect
        create (bool): if True and graph doesn't exist, attempt to create it

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    # python type mapping for providing readable types to LLM
    types = {
        "str": "STRING",
        "float": "DOUBLE",
        "int": "INTEGER",
        "list": "LIST",
        "dict": "MAP",
        "bool": "BOOLEAN",
    }

    # precompiled regex for checking chars in graph labels
    label_regex: Pattern = re.compile("[^0-9a-zA-Z]+")

    def __init__(
        self, graph_name: str, conf: Dict[str, Any], create: bool = True
    ) -> None:
        """Create a new AGEGraph instance.

        Connects with ``psycopg2.connect(**conf)``, looks the graph up in
        ``ag_catalog.ag_graph``, optionally creates it, stores its id in
        ``self.graphid`` and refreshes the schema.

        Raises:
            ImportError: if psycopg2 is not installed
            AGEQueryException: if creating the graph fails
            Exception: if the graph does not exist and ``create`` is False
        """

        self.graph_name = graph_name

        # check that psycopg2 is installed
        try:
            import psycopg2
        except ImportError as e:
            raise ImportError(
                "Could not import psycopg2 python package. "
                "Please install it with `pip install psycopg2`."
            ) from e

        self.connection = psycopg2.connect(**conf)

        # The graph name is passed as a bound parameter rather than being
        # formatted into the SQL text, so quoting is handled by the driver.
        graph_id_query = "SELECT graphid FROM ag_catalog.ag_graph WHERE name = %s"

        with self._get_cursor() as curs:
            # check if graph with name graph_name exists
            curs.execute(graph_id_query, (graph_name,))
            data = curs.fetchone()

            # if graph doesn't exist and create is True, create it
            if data is None:
                if not create:
                    raise Exception(
                        (
                            'Graph "{}" does not exist in the database '
                            + 'and "create" is set to False'
                        ).format(graph_name)
                    )

                try:
                    curs.execute(
                        "SELECT ag_catalog.create_graph(%s);", (graph_name,)
                    )
                    self.connection.commit()
                except psycopg2.Error as e:
                    raise AGEQueryException(
                        {
                            "message": "Could not create the graph",
                            "detail": str(e),
                        }
                    ) from e

                curs.execute(graph_id_query, (graph_name,))
                data = curs.fetchone()

            # store graph id and refresh the schema
            self.graphid = data.graphid
            self.refresh_schema()

    def _get_cursor(self) -> psycopg2.extras.NamedTupleCursor:
        """
        get cursor, load age extension and set search path

        Returns:
            a NamedTupleCursor on ``self.connection`` on which ``LOAD 'age'``
            and the ``search_path`` setup have already been executed
        """

        try:
            import psycopg2.extras
        except ImportError as e:
            raise ImportError(
                "Unable to import psycopg2, please install with "
                "`pip install -U psycopg2`."
            ) from e
        cursor = self.connection.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
        cursor.execute("""LOAD 'age';""")
        cursor.execute("""SET search_path = ag_catalog, "$user", public;""")
        return cursor

    def _get_labels(self) -> Tuple[List[str], List[str]]:
        """
        Get all labels of a graph (for both edges and vertices)
        by running two cypher queries through ``self.query``

        Returns
            Tuple[List[str], List[str]]: 2 lists, the first containing vertex
                labels and the second containing edge labels
        """

        e_labels_records = self.query(
            """MATCH ()-[e]-() RETURN collect(distinct label(e)) as labels"""
        )
        e_labels = e_labels_records[0]["labels"] if e_labels_records else []

        n_labels_records = self.query(
            """MATCH (n) RETURN collect(distinct label(n)) as labels"""
        )
        n_labels = n_labels_records[0]["labels"] if n_labels_records else []

        return n_labels, e_labels

    def _get_triples(self, e_labels: List[str]) -> List[Dict[str, str]]:
        """
        Get a set of distinct relationship types (as a list of dicts) in the graph
        to be used as context by an llm.

        Args:
            e_labels (List[str]): a list of edge labels to filter for

        Returns:
            List[Dict[str, str]]: relationships as a list of dicts in the format
                "{'start':<from_label>, 'type':<edge_label>, 'end':<to_label>}"

        Raises:
            AGEQueryException: if one of the per-label queries fails
        """

        # age query to get distinct relationship types
        try:
            import psycopg2
        except ImportError as e:
            raise ImportError(
                "Unable to import psycopg2, please install with "
                "`pip install -U psycopg2`."
            ) from e
        triple_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH (a)-[e:`{e_label}`]->(b)
            WITH a,e,b LIMIT 3000
            RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
            LIMIT 10
        $$) AS (f agtype, edge agtype, t agtype);
        """

        triple_schema = []

        # iterate desired edge types and add distinct relationship types to result
        with self._get_cursor() as curs:
            for label in e_labels:
                q = triple_query.format(graph_name=self.graph_name, e_label=label)
                try:
                    curs.execute(q)
                    data = curs.fetchall()
                    for d in data:
                        # use json.loads to convert returned
                        # strings to python primitives
                        triple_schema.append(
                            {
                                "start": json.loads(d.f)[0],
                                "type": json.loads(d.edge),
                                "end": json.loads(d.t)[0],
                            }
                        )
                except psycopg2.Error as e:
                    # same recovery as query(): leave the connection usable
                    self.connection.rollback()
                    raise AGEQueryException(
                        {
                            "message": "Error fetching triples",
                            "detail": str(e),
                        }
                    ) from e

        return triple_schema

    def _get_triples_str(self, e_labels: List[str]) -> List[str]:
        """
        Get a set of distinct relationship types (as a list of strings) in the graph
        to be used as context by an llm.

        Args:
            e_labels (List[str]): a list of edge labels to filter for

        Returns:
            List[str]: relationships as a list of strings in the format
                "(:`<from_label>`)-[:`<edge_label>`]->(:`<to_label>`)"
        """

        triples = self._get_triples(e_labels)

        return self._format_triples(triples)

    @staticmethod
    def _format_triples(triples: List[Dict[str, str]]) -> List[str]:
        """
        Convert a list of relationships from dictionaries to formatted strings
        to be better readable by an llm

        Args:
            triples (List[Dict[str,str]]): a list relationships in the form
                {'start':<from_label>, 'type':<edge_label>, 'end':<to_label>}

        Returns:
            List[str]: a list of relationships in the form
                "(:`<from_label>`)-[:`<edge_label>`]->(:`<to_label>`)"
        """
        triple_template = "(:`{start}`)-[:`{type}`]->(:`{end}`)"
        return [triple_template.format(**triple) for triple in triples]

    def _get_node_properties(self, n_labels: List[str]) -> List[Dict[str, Any]]:
        """
        Fetch a list of available node properties by node label to be used
        as context for an llm

        Up to 100 nodes per label are sampled to discover property names and
        the python type of their values.

        Args:
            n_labels (List[str]): a list of node labels to filter for

        Returns:
            List[Dict[str, Any]]: a list of node labels and
                their corresponding properties in the form
                "{
                    'labels': <node_label>,
                    'properties': [
                        {
                            'property': <property_name>,
                            'type': <property_type>
                        },...
                        ]
                }"

        Raises:
            AGEQueryException: if one of the per-label queries fails
        """
        try:
            import psycopg2
        except ImportError as e:
            raise ImportError(
                "Unable to import psycopg2, please install with "
                "`pip install -U psycopg2`."
            ) from e

        # cypher query to fetch properties of a given label
        node_properties_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH (a:`{n_label}`)
            RETURN properties(a) AS props
            LIMIT 100
        $$) AS (props agtype);
        """

        node_properties = []
        with self._get_cursor() as curs:
            for label in n_labels:
                q = node_properties_query.format(
                    graph_name=self.graph_name, n_label=label
                )

                try:
                    curs.execute(q)
                except psycopg2.Error as e:
                    # same recovery as query(): leave the connection usable
                    self.connection.rollback()
                    raise AGEQueryException(
                        {
                            "message": "Error fetching node properties",
                            "detail": str(e),
                        }
                    ) from e
                data = curs.fetchall()

                # build a set of distinct properties
                s = set()
                for d in data:
                    # use json.loads to convert to python
                    # primitive and get readable type
                    for k, v in json.loads(d.props).items():
                        s.add((k, self.types[type(v).__name__]))

                np = {
                    "properties": [{"property": k, "type": v} for k, v in s],
                    "labels": label,
                }
                node_properties.append(np)

        return node_properties

    def _get_edge_properties(self, e_labels: List[str]) -> List[Dict[str, Any]]:
        """
        Fetch a list of available edge properties by edge label to be used
        as context for an llm

        Up to 100 edges per label are sampled to discover property names and
        the python type of their values.

        Args:
            e_labels (List[str]): a list of edge labels to filter for

        Returns:
            List[Dict[str, Any]]: a list of edge labels
                and their corresponding properties in the form
                "{
                    'type': <edge_label>,
                    'properties': [
                        {
                            'property': <property_name>,
                            'type': <property_type>
                        },...
                        ]
                }"

        Raises:
            AGEQueryException: if one of the per-label queries fails
        """

        try:
            import psycopg2
        except ImportError as e:
            raise ImportError(
                "Unable to import psycopg2, please install with "
                "`pip install -U psycopg2`."
            ) from e
        # cypher query to fetch properties of a given label
        edge_properties_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH ()-[e:`{e_label}`]->()
            RETURN properties(e) AS props
            LIMIT 100
        $$) AS (props agtype);
        """
        edge_properties = []
        with self._get_cursor() as curs:
            for label in e_labels:
                q = edge_properties_query.format(
                    graph_name=self.graph_name, e_label=label
                )

                try:
                    curs.execute(q)
                except psycopg2.Error as e:
                    # same recovery as query(): leave the connection usable
                    self.connection.rollback()
                    raise AGEQueryException(
                        {
                            "message": "Error fetching edge properties",
                            "detail": str(e),
                        }
                    ) from e
                data = curs.fetchall()

                # build a set of distinct properties
                s = set()
                for d in data:
                    # use json.loads to convert to python
                    # primitive and get readable type
                    for k, v in json.loads(d.props).items():
                        s.add((k, self.types[type(v).__name__]))

                np = {
                    "properties": [{"property": k, "type": v} for k, v in s],
                    "type": label,
                }
                edge_properties.append(np)

        return edge_properties

    def refresh_schema(self) -> None:
        """
        Refresh the graph schema information by updating the available
        labels, relationships, and properties

        Sets ``self.schema`` (formatted string) and ``self.structured_schema``
        (dictionary with ``node_props``, ``rel_props``, ``relationships`` and
        ``metadata`` entries).
        """

        # fetch graph schema information
        n_labels, e_labels = self._get_labels()
        triple_schema = self._get_triples(e_labels)

        node_properties = self._get_node_properties(n_labels)
        edge_properties = self._get_edge_properties(e_labels)

        # update the formatted string representation
        self.schema = f"""
        Node properties are the following:
        {node_properties}
        Relationship properties are the following:
        {edge_properties}
        The relationships are the following:
        {self._format_triples(triple_schema)}
        """

        # update the dictionary representation
        self.structured_schema = {
            "node_props": {el["labels"]: el["properties"] for el in node_properties},
            "rel_props": {el["type"]: el["properties"] for el in edge_properties},
            "relationships": triple_schema,
            "metadata": {},
        }

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Graph"""
        return self.schema

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        """Returns the structured schema of the Graph"""
        return self.structured_schema

    @staticmethod
    def _get_col_name(field: str, idx: int) -> str:
        """
        Convert a cypher return field to a pgsql select field
        If possible keep the cypher column name, but create a generic name if necessary

        Args:
            field (str): a return field from a cypher query to be formatted for pgsql
            idx (int): the position of the field in the return statement

        Returns:
            str: the field to be used in the pgsql select statement
        """
        # remove white space
        field = field.strip()
        # if an alias is provided for the field, use it
        if " as " in field:
            return field.split(" as ")[-1].strip()
        # if the return value is an unnamed primitive, give it a generic name
        if field.isnumeric() or field in ("true", "false", "null"):
            return f"column_{idx}"
        # otherwise return the value stripping out some common special chars
        return field.replace("(", "_").replace(")", "")

    @staticmethod
    def _wrap_query(query: str, graph_name: str) -> str:
        """
        Convert a Cypher query to an Apache Age compatible Sql Query.
        Handles combined queries with UNION/EXCEPT operators

        Args:
            query (str) : A valid cypher query, can include UNION/EXCEPT operators
            graph_name (str) : The name of the graph to query

        Returns :
            str : An equivalent pgSql query wrapped with ag_catalog.cypher

        Raises:
            ValueError : If query is empty or contains RETURN *
        """

        if not query.strip():
            raise ValueError("Empty query provided")

        # pgsql template
        template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
            {query}
        $$) AS ({fields});"""

        # split the query into parts based on UNION and EXCEPT
        parts = re.split(r"\b(UNION\b|\bEXCEPT)\b", query, flags=re.IGNORECASE)

        all_fields = []

        for part in parts:
            if part.strip().upper() in ("UNION", "EXCEPT"):
                continue

            # if there are any returned fields they must be added to the pgsql query
            return_match = re.search(r'\breturn\b(?![^"]*")', part, re.IGNORECASE)
            if not return_match:
                continue

            # Extract the part of the query after the RETURN keyword
            return_clause = part[return_match.end() :]

            # parse return statement to identify returned fields
            fields = (
                return_clause.lower()
                .split("distinct")[-1]
                .split("order by")[0]
                .split("skip")[0]
                .split("limit")[0]
                .split(",")
            )

            # raise exception if RETURN * is found as we can't resolve the fields
            clean_fields = [f.strip() for f in fields if f.strip()]
            if "*" in clean_fields:
                raise ValueError(
                    "Apache Age does not support RETURN * in Cypher queries"
                )

            # Format fields and maintain order of appearance
            for idx, field in enumerate(clean_fields):
                field_name = AGEGraph._get_col_name(field, idx)
                if field_name not in all_fields:
                    all_fields.append(field_name)

        # if no return statements found in any part
        if not all_fields:
            fields_str = "a agtype"
        else:
            fields_str = ", ".join(f"{field} agtype" for field in all_fields)

        return template.format(
            graph_name=graph_name,
            query=query,
            fields=fields_str,
            projection="*",
        )

    @staticmethod
    def _split_agtype(value: Any) -> Tuple[Any, str]:
        """
        Split a raw column value into its payload and trailing type annotation

        Only a suffix of the form ``::<identifier>`` at the very end of a
        string counts as an annotation, so a ``::`` that occurs inside the
        payload (for example inside a string property) is left untouched.

        Args:
            value (Any): a raw column value from an age query result

        Returns:
            Tuple[Any, str]: ``(payload, dtype)``; dtype is ``""`` when the
                value is not a string or carries no type annotation
        """
        if isinstance(value, str):
            payload, sep, dtype = value.rpartition("::")
            if sep and dtype.isidentifier():
                return payload, dtype
        return value, ""

    @staticmethod
    def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
        """
        Convert a record returned from an age query to a dictionary

        Args:
            record (): a record from an age query result

        Returns:
            Dict[str, Any]: a dictionary representation of the record where
                the dictionary key is the field name and the value is the
                value converted to a python type
        """
        # agtype comes back '{key: value}::type' which must be parsed;
        # do it once per field and reuse the result in both passes below
        parsed = [
            (name, *AGEGraph._split_agtype(getattr(record, name)))
            for name in record._fields
        ]

        # prebuild a mapping of vertex_id to vertex mappings to be used
        # later to build edges
        vertices = {}
        for _, payload, dtype in parsed:
            if dtype == "vertex":
                vertex = json.loads(payload)
                vertices[vertex["id"]] = vertex.get("properties")

        # result holder
        d = {}

        # iterate returned fields and parse appropriately
        for name, payload, dtype in parsed:
            if dtype == "vertex":
                d[name] = json.loads(payload).get("properties")
            # convert edge from id-label->id by replacing id with node information
            # we only do this if the vertex was also returned in the query
            # this is an attempt to be consistent with neo4j implementation
            elif dtype == "edge":
                edge = json.loads(payload)
                d[name] = (
                    vertices.get(edge["start_id"], {}),
                    edge["label"],
                    vertices.get(edge["end_id"], {}),
                )
            else:
                d[name] = json.loads(payload) if isinstance(payload, str) else payload

        return d

    def query(
        self, query: str, params: Union[dict, None] = None
    ) -> List[Dict[str, Any]]:
        """
        Query the graph by taking a cypher query, converting it to an
        age compatible query, executing it and converting the result

        Args:
            query (str): a cypher query to be executed
            params (dict): parameters for the query (not used in this implementation)

        Returns:
            List[Dict[str, Any]]: a list of dictionaries containing the result set

        Raises:
            ValueError: if the query is empty or uses RETURN *
            AGEQueryException: if executing the wrapped query fails; the
                transaction is rolled back first
        """
        try:
            import psycopg2
        except ImportError as e:
            raise ImportError(
                "Unable to import psycopg2, please install with "
                "`pip install -U psycopg2`."
            ) from e

        # convert cypher query to pgsql/age query
        wrapped_query = self._wrap_query(query, self.graph_name)

        # execute the query, rolling back on an error
        with self._get_cursor() as curs:
            try:
                curs.execute(wrapped_query)
                self.connection.commit()
            except psycopg2.Error as e:
                self.connection.rollback()
                raise AGEQueryException(
                    {
                        "message": "Error executing graph query: {}".format(query),
                        "detail": str(e),
                    }
                ) from e

            data = curs.fetchall()
            if data is None:
                return []

            # convert to dictionaries
            return [self._record_to_dict(d) for d in data]

    @staticmethod
    def _format_properties(
        properties: Dict[str, Any], id: Union[str, None] = None
    ) -> str:
        """
        Convert a dictionary of properties to a string representation that
        can be used in a cypher query insert/merge statement.

        Args:
            properties (Dict[str,str]): a dictionary containing node/edge properties
            id (Union[str, None]): the id of the node or None if none exists

        Returns:
            str: the properties dictionary as a properly formatted string
        """
        # wrap property key in backticks to escape
        props = [f"`{k}`: {json.dumps(v)}" for k, v in properties.items()]
        if id is not None and "id" not in properties:
            props.append(
                f"id: {json.dumps(id)}" if isinstance(id, str) else f"id: {id}"
            )
        return "{" + ", ".join(props) + "}"

    @staticmethod
    def clean_graph_labels(label: str) -> str:
        """
        remove any disallowed characters from a label and replace with '_'

        Args:
            label (str): the original label

        Returns:
            str: the sanitized version of the label
        """
        return AGEGraph.label_regex.sub("_", label)

    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """
        insert a list of graph documents into the graph

        Note that the ``properties`` dict of every inserted node, and of the
        source/target node of every relationship, gets its ``"id"`` entry set
        to the node id; when ``include_source`` is True a missing source
        ``metadata["id"]`` is filled with the md5 of the page content.

        Args:
            graph_documents (List[GraphDocument]): the list of documents to be inserted
            include_source (bool): if True add nodes for the sources
                with MENTIONS edges to the entities they mention

        Returns:
            None
        """
        # query for inserting nodes
        node_insert_query = (
            """
            MERGE (n:`{label}` {{`id`: "{id}"}})
            SET n = {properties}
            """
            if not include_source
            else """
            MERGE (n:`{label}` {properties})
            MERGE (d:Document {d_properties})
            MERGE (d)-[:MENTIONS]->(n)
        """
        )

        # query for inserting edges
        edge_insert_query = """
            MERGE (from:`{f_label}` {f_properties})
            MERGE (to:`{t_label}` {t_properties})
            MERGE (from)-[:`{r_label}` {r_properties}]->(to)
        """
        # iterate docs and insert them
        for doc in graph_documents:
            # if we are adding sources, create an id for the source
            if include_source:
                if not doc.source.metadata.get("id"):
                    doc.source.metadata["id"] = md5(
                        doc.source.page_content.encode("utf-8")
                    ).hexdigest()

            # insert entity nodes
            for node in doc.nodes:
                node.properties["id"] = node.id
                # sanitize the label on both paths so nodes created here carry
                # the same label the relationship MERGE below will look for
                label = AGEGraph.clean_graph_labels(node.type)
                if include_source:
                    query = node_insert_query.format(
                        label=label,
                        properties=self._format_properties(node.properties),
                        d_properties=self._format_properties(doc.source.metadata),
                    )
                else:
                    query = node_insert_query.format(
                        label=label,
                        properties=self._format_properties(node.properties),
                        id=node.id,
                    )

                self.query(query)

            # insert relationships
            for edge in doc.relationships:
                edge.source.properties["id"] = edge.source.id
                edge.target.properties["id"] = edge.target.id
                inputs = {
                    "f_label": AGEGraph.clean_graph_labels(edge.source.type),
                    "f_properties": self._format_properties(edge.source.properties),
                    "t_label": AGEGraph.clean_graph_labels(edge.target.type),
                    "t_properties": self._format_properties(edge.target.properties),
                    "r_label": AGEGraph.clean_graph_labels(edge.type).upper(),
                    "r_properties": self._format_properties(edge.properties),
                }

                query = edge_insert_query.format(**inputs)
                self.query(query)
|
||||
@@ -0,0 +1,182 @@
|
||||
import os
|
||||
from math import ceil
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class ArangoGraph:
|
||||
"""ArangoDB wrapper for graph operations.
|
||||
|
||||
*Security note*: Make sure that the database connection uses credentials
|
||||
that are narrowly-scoped to only include necessary permissions.
|
||||
Failure to do so may result in data corruption or loss, since the calling
|
||||
code may attempt commands that would result in deletion, mutation
|
||||
of data if appropriately prompted or reading sensitive data if such
|
||||
data is present in the database.
|
||||
The best way to guard against such negative outcomes is to (as appropriate)
|
||||
limit the permissions granted to the credentials used with this tool.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
def __init__(self, db: Any) -> None:
|
||||
"""Create a new ArangoDB graph wrapper instance."""
|
||||
self.set_db(db)
|
||||
self.set_schema()
|
||||
|
||||
@property
|
||||
def db(self) -> Any:
|
||||
return self.__db
|
||||
|
||||
@property
|
||||
def schema(self) -> Dict[str, Any]:
|
||||
return self.__schema
|
||||
|
||||
def set_db(self, db: Any) -> None:
|
||||
from arango.database import Database
|
||||
|
||||
if not isinstance(db, Database):
|
||||
msg = "**db** parameter must inherit from arango.database.Database"
|
||||
raise TypeError(msg)
|
||||
|
||||
self.__db: Database = db
|
||||
self.set_schema()
|
||||
|
||||
def set_schema(self, schema: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""
|
||||
Set the schema of the ArangoDB Database.
|
||||
Auto-generates Schema if **schema** is None.
|
||||
"""
|
||||
self.__schema = self.generate_schema() if schema is None else schema
|
||||
|
||||
def generate_schema(
|
||||
self, sample_ratio: float = 0
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Generates the schema of the ArangoDB Database and returns it
|
||||
User can specify a **sample_ratio** (0 to 1) to determine the
|
||||
ratio of documents/edges used (in relation to the Collection size)
|
||||
to render each Collection Schema.
|
||||
"""
|
||||
if not 0 <= sample_ratio <= 1:
|
||||
raise ValueError("**sample_ratio** value must be in between 0 to 1")
|
||||
|
||||
# Stores the Edge Relationships between each ArangoDB Document Collection
|
||||
graph_schema: List[Dict[str, Any]] = [
|
||||
{"graph_name": g["name"], "edge_definitions": g["edge_definitions"]}
|
||||
for g in self.db.graphs()
|
||||
]
|
||||
|
||||
# Stores the schema of every ArangoDB Document/Edge collection
|
||||
collection_schema: List[Dict[str, Any]] = []
|
||||
|
||||
for collection in self.db.collections():
|
||||
if collection["system"]:
|
||||
continue
|
||||
|
||||
# Extract collection name, type, and size
|
||||
col_name: str = collection["name"]
|
||||
col_type: str = collection["type"]
|
||||
col_size: int = self.db.collection(col_name).count()
|
||||
|
||||
# Skip collection if empty
|
||||
if col_size == 0:
|
||||
continue
|
||||
|
||||
# Set number of ArangoDB documents/edges to retrieve
|
||||
limit_amount = ceil(sample_ratio * col_size) or 1
|
||||
|
||||
aql = f"""
|
||||
FOR doc in `{col_name}`
|
||||
LIMIT {limit_amount}
|
||||
RETURN doc
|
||||
"""
|
||||
|
||||
doc: Dict[str, Any]
|
||||
properties: List[Dict[str, str]] = []
|
||||
for doc in self.__db.aql.execute(aql):
|
||||
for key, value in doc.items():
|
||||
properties.append({"name": key, "type": type(value).__name__})
|
||||
|
||||
collection_schema.append(
|
||||
{
|
||||
"collection_name": col_name,
|
||||
"collection_type": col_type,
|
||||
f"{col_type}_properties": properties,
|
||||
f"example_{col_type}": doc,
|
||||
}
|
||||
)
|
||||
|
||||
return {"Graph Schema": graph_schema, "Collection Schema": collection_schema}
|
||||
|
||||
def query(
    self, query: str, top_k: Optional[int] = None, **kwargs: Any
) -> List[Dict[str, Any]]:
    """Run an AQL statement and collect its results into a list.

    Args:
        query: AQL statement handed to the database's ``aql.execute``.
        top_k: Maximum number of results to take from the cursor. ``None``
            takes everything the cursor yields.
        **kwargs: Extra keyword arguments forwarded to ``aql.execute``.

    Returns:
        The first ``top_k`` items yielded by the cursor, in cursor order.
    """
    from itertools import islice

    results = self.__db.aql.execute(query, **kwargs)
    # islice(..., None) imposes no limit, so top_k=None returns every result.
    return list(islice(results, top_k))
|
||||
|
||||
@classmethod
def from_db_credentials(
    cls,
    url: Optional[str] = None,
    dbname: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> Any:
    """Alternate constructor: connect with credentials, then wrap the database.

    All arguments are forwarded unchanged to ``get_arangodb_client``, which
    resolves each one from the argument, then an environment variable, then
    a built-in default.

    Args:
        url: Arango DB url. Can be passed in as named arg or set as environment
            var ``ARANGODB_URL``. Defaults to "http://localhost:8529".
        dbname: Arango DB name. Can be passed in as named arg or set as
            environment var ``ARANGODB_DBNAME``. Defaults to "_system".
        username: Can be passed in as named arg or set as environment var
            ``ARANGODB_USERNAME``. Defaults to "root".
        password: Can be passed in as named arg or set as environment var
            ``ARANGODB_PASSWORD``. Defaults to "".

    Returns:
        A new instance of ``cls`` built from the connected database handle.
    """
    return cls(
        get_arangodb_client(
            url=url, dbname=dbname, username=username, password=password
        )
    )
|
||||
|
||||
|
||||
def get_arangodb_client(
    url: Optional[str] = None,
    dbname: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> Any:
    """Connect to an ArangoDB database and return the database handle.

    Every setting is resolved the same way: the explicit argument when it is
    truthy, otherwise the environment variable, otherwise the default. An
    explicitly passed empty string therefore also falls through to the
    environment variable.

    Args:
        url: Arango DB url. Environment var ``ARANGODB_URL``.
            Defaults to "http://localhost:8529".
        dbname: Arango DB name. Environment var ``ARANGODB_DBNAME``.
            Defaults to "_system".
        username: Environment var ``ARANGODB_USERNAME``. Defaults to "root".
        password: Environment var ``ARANGODB_PASSWORD``. Defaults to "".

    Returns:
        Whatever ``ArangoClient(url).db(...)`` returns; the call is made
        with ``verify=True``.

    Raises:
        ImportError: If the ``arango`` module cannot be imported.
    """
    try:
        from arango import ArangoClient
    except ImportError as e:
        raise ImportError(
            "Unable to import arango, please install with `pip install python-arango`."
        ) from e

    def _setting(explicit: Optional[str], env_var: str, fallback: str) -> str:
        # Argument wins when truthy; otherwise environment, then the default.
        return explicit or os.environ.get(env_var, fallback)

    resolved_url = _setting(url, "ARANGODB_URL", "http://localhost:8529")
    resolved_dbname = _setting(dbname, "ARANGODB_DBNAME", "_system")
    resolved_username = _setting(username, "ARANGODB_USERNAME", "root")
    resolved_password = _setting(password, "ARANGODB_PASSWORD", "")

    client = ArangoClient(resolved_url)
    return client.db(
        resolved_dbname, resolved_username, resolved_password, verify=True
    )
|
||||
@@ -0,0 +1,201 @@
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument
|
||||
from langchain_community.graphs.graph_store import GraphStore
|
||||
|
||||
# Cypher used by FalkorDBGraph.refresh_schema to introspect the graph.

# One row per node label: {label, keys}, where keys are the distinct property
# names seen on nodes carrying that label. Nodes without properties contribute
# a single NULL key so their label still goes through the UNWIND.
node_properties_query = """
MATCH (n)
WITH keys(n) as keys, labels(n) AS labels
WITH CASE WHEN keys = [] THEN [NULL] ELSE keys END AS keys, labels
UNWIND labels AS label
UNWIND keys AS key
WITH label, collect(DISTINCT key) AS keys
RETURN {label:label, keys:keys} AS output
"""

# One row per relationship type: {types, keys}, where keys are the distinct
# property names seen on relationships of that type. Note the output map key
# is spelled "types" even though it holds a single type.
rel_properties_query = """
MATCH ()-[r]->()
WITH keys(r) as keys, type(r) AS types
WITH CASE WHEN keys = [] THEN [NULL] ELSE keys END AS keys, types
UNWIND types AS type
UNWIND keys AS key WITH type,
collect(DISTINCT key) AS keys
RETURN {types:type, keys:keys} AS output
"""

# One row per distinct (source label, relationship type, target label) triple
# present in the graph, as {start, type, end}.
rel_query = """
MATCH (n)-[r]->(m)
UNWIND labels(n) as src_label
UNWIND labels(m) as dst_label
UNWIND type(r) as rel_type
RETURN DISTINCT {start: src_label, type: rel_type, end: dst_label} AS output
"""
|
||||
|
||||
|
||||
class FalkorDBGraph(GraphStore):
    """FalkorDB wrapper for graph operations.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        database: str,
        host: str = "localhost",
        port: int = 6379,
        username: Optional[str] = None,
        password: Optional[str] = None,
        ssl: bool = False,
    ) -> None:
        """Create a new FalkorDB graph wrapper instance.

        The ``falkordb`` package is tried first. If importing it fails, the
        deprecated ``redis`` client is used instead. The schema is loaded
        immediately after connecting.

        Args:
            database: Name of the graph to select.
            host: Server host name.
            port: Server port.
            username: Optional user name.
            password: Optional password.
            ssl: Whether to connect over SSL.

        Raises:
            ImportError: If neither ``falkordb`` nor ``redis`` can be imported.
            ConnectionError: If the ``falkordb`` driver cannot be created.
            ValueError: If the initial schema refresh fails.
        """
        try:
            self.__init_falkordb_connection(
                database, host, port, username, password, ssl
            )

        except ImportError:
            try:
                # Falls back to using the redis package just for backwards compatibility
                self.__init_redis_connection(
                    database, host, port, username, password, ssl
                )
            except ImportError as e:
                # Chain the cause so the underlying import failure stays visible.
                raise ImportError(
                    "Could not import falkordb python package. "
                    "Please install it with `pip install falkordb`."
                ) from e

        self.schema: str = ""
        self.structured_schema: Dict[str, Any] = {}

        try:
            self.refresh_schema()
        except Exception as e:
            raise ValueError(f"Could not refresh schema. Error: {e}") from e

    def __init_falkordb_connection(
        self,
        database: str,
        host: str = "localhost",
        port: int = 6379,
        username: Optional[str] = None,
        password: Optional[str] = None,
        ssl: bool = False,
    ) -> None:
        """Connect through the ``falkordb`` package and select ``database``.

        Raises:
            ImportError: If ``falkordb`` is not installed (the caller relies
                on this to trigger the redis fallback).
            ConnectionError: If constructing the driver fails.
        """
        from falkordb import FalkorDB

        try:
            self._driver = FalkorDB(
                host=host, port=port, username=username, password=password, ssl=ssl
            )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to FalkorDB: {e}") from e

        self._graph = self._driver.select_graph(database)

    @deprecated("0.0.31", alternative="__init_falkordb_connection")
    def __init_redis_connection(
        self,
        database: str,
        host: str = "localhost",
        port: int = 6379,
        username: Optional[str] = None,
        password: Optional[str] = None,
        ssl: bool = False,
    ) -> None:
        """Connect through the legacy ``redis`` package (deprecated path).

        Raises:
            ImportError: If ``redis`` or its graph commands cannot be imported.
        """
        import redis
        from redis.commands.graph import Graph

        # show deprecation warning
        warnings.warn(
            "Using the redis package is deprecated. "
            "Please use the falkordb package instead, "
            "install it with `pip install falkordb`.",
            DeprecationWarning,
        )

        self._driver = redis.Redis(
            host=host, port=port, username=username, password=password, ssl=ssl
        )

        self._graph = Graph(self._driver, database)

    @property
    def get_schema(self) -> str:
        """Returns the schema of the FalkorDB database"""
        return self.schema

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        """Returns the structured schema of the Graph"""
        return self.structured_schema

    def refresh_schema(self) -> None:
        """Refreshes the schema of the FalkorDB database.

        Runs the three module-level introspection queries and stores both a
        structured form (``structured_schema``) and a text form (``schema``).
        """
        node_properties: List[Any] = self.query(node_properties_query)
        rel_properties: List[Any] = self.query(rel_properties_query)
        relationships: List[Any] = self.query(rel_query)

        # Each result row is a one-element sequence holding the "output" map.
        self.structured_schema = {
            "node_props": {el[0]["label"]: el[0]["keys"] for el in node_properties},
            "rel_props": {el[0]["types"]: el[0]["keys"] for el in rel_properties},
            "relationships": [el[0] for el in relationships],
        }

        self.schema = (
            f"Node properties: {node_properties}\n"
            f"Relationships properties: {rel_properties}\n"
            f"Relationships: {relationships}\n"
        )

    def query(
        self, query: str, params: Optional[dict] = None
    ) -> List[Dict[str, Any]]:
        """Query FalkorDB database.

        Args:
            query: Cypher statement to execute.
            params: Query parameters; ``None`` is treated as no parameters.

        Returns:
            The ``result_set`` of the driver's response.

        Raises:
            ValueError: If the driver raises while executing the statement.
        """
        try:
            # A fresh dict per call avoids sharing one default object with
            # the driver across invocations.
            data = self._graph.query(query, params if params is not None else {})
            return data.result_set
        except Exception as e:
            raise ValueError(f"Generated Cypher Statement is not valid\n{e}") from e

    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """Take GraphDocument as input and use it to construct a graph.

        Nodes are merged on ``(type, id)`` and relationships on
        ``(source, type, target)``; properties are added with ``SET +=``.

        Args:
            graph_documents: Documents whose nodes and relationships are written.
            include_source: Accepted for interface compatibility; this
                implementation does not use it.
        """
        for document in graph_documents:
            # Import nodes
            for node in document.nodes:
                # Ids are sent as parameters (stringified, matching the old
                # quoted form) so quotes inside an id cannot break the query.
                # Labels cannot be parameterised in Cypher and are still
                # interpolated.
                self.query(
                    (
                        f"MERGE (n:{node.type} {{id: $id}}) "
                        "SET n += $properties "
                        "RETURN distinct 'done' AS result"
                    ),
                    {"id": str(node.id), "properties": node.properties},
                )

            # Import relationships
            for rel in document.relationships:
                rel_type = rel.type.replace(" ", "_").upper()
                self.query(
                    (
                        f"MATCH (a:{rel.source.type} {{id: $source_id}}), "
                        f"(b:{rel.target.type} {{id: $target_id}}) "
                        f"MERGE (a)-[r:{rel_type}]->(b) "
                        "SET r += $properties "
                        "RETURN distinct 'done' AS result"
                    ),
                    {
                        "source_id": str(rel.source.id),
                        "target_id": str(rel.target.id),
                        "properties": rel.properties,
                    },
                )
|
||||
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class Node(Serializable):
    """Represents a node in a graph with associated properties.

    Attributes:
        id (Union[str, int]): A unique identifier for the node.
        type (str): The type or label of the node, default is "Node".
        properties (dict): Additional properties and metadata associated with the node.
    """

    # No default: an id must be supplied for every node.
    id: Union[str, int]
    type: str = "Node"
    # default_factory builds a new dict per instance rather than sharing one.
    properties: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Relationship(Serializable):
    """Represents a directed relationship between two nodes in a graph.

    The direction runs from ``source`` to ``target``.

    Attributes:
        source (Node): The source node of the relationship.
        target (Node): The target node of the relationship.
        type (str): The type of the relationship.
        properties (dict): Additional properties associated with the relationship.
    """

    source: Node
    target: Node
    # No default: unlike Node.type, a relationship type must be given.
    type: str
    # default_factory builds a new dict per instance rather than sharing one.
    properties: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class GraphDocument(Serializable):
    """Represents a graph document consisting of nodes and relationships.

    All three fields are required; none has a default.

    Attributes:
        nodes (List[Node]): A list of nodes in the graph.
        relationships (List[Relationship]): A list of relationships in the graph.
        source (Document): The document from which the graph information is derived.
    """

    nodes: List[Node]
    relationships: List[Relationship]
    source: Document
|
||||
@@ -0,0 +1,37 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument
|
||||
|
||||
|
||||
class GraphStore:
    """Interface implemented by the graph database wrappers.

    NOTE(review): the members are decorated with ``abstractmethod``, but the
    class does not derive from ``abc.ABC``, so Python does not enforce the
    markers at instantiation time.
    """

    @property
    @abstractmethod
    def get_schema(self) -> str:
        """Return the schema of the Graph database as text."""

    @property
    @abstractmethod
    def get_structured_schema(self) -> Dict[str, Any]:
        """Return the schema of the Graph database as a dictionary."""

    @abstractmethod
    def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
        """Run ``query`` against the graph, optionally with ``params``."""

    @abstractmethod
    def refresh_schema(self) -> None:
        """Reload the graph schema information."""

    @abstractmethod
    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """Write the given GraphDocument objects into the graph."""
|
||||
@@ -0,0 +1,228 @@
|
||||
import hashlib
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.utils import get_from_env
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||
from langchain_community.graphs.graph_store import GraphStore
|
||||
|
||||
|
||||
class GremlinGraph(GraphStore):
    """Gremlin wrapper for graph operations.

    Parameters:
    url (Optional[str]): The URL of the Gremlin database server or env GREMLIN_URI
    username (Optional[str]): The collection-identifier like '/dbs/database/colls/graph'
                               or env GREMLIN_USERNAME if none provided
    password (Optional[str]): The connection-key for database authentication
                              or env GREMLIN_PASSWORD if none provided
    traversal_source (str): The traversal source to use for queries. Defaults to 'g'.
    message_serializer (Optional[Any]): The message serializer to use for requests.
                                        Defaults to serializer.GraphSONSerializersV2d0()
    include_edge_properties (bool): Whether to include edge properties in
                                    the gremlin graph schema. Defaults to False
    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.

    *Implementation details*:
        The Gremlin queries are designed to work with Azure CosmosDB limitations
    """

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        """Returns the structured schema, loading it on first access."""
        if not self.structured_schema:
            self.refresh_schema()
        return self.structured_schema

    def __init__(
        self,
        url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        traversal_source: str = "g",
        message_serializer: Optional[Any] = None,
        include_edge_properties: bool = False,
    ) -> None:
        """Create a new Gremlin graph wrapper instance.

        No query is sent here; the schema is loaded lazily on first access.

        Raises:
            ImportError: If ``gremlin_python`` cannot be imported.
        """
        try:
            import asyncio

            from gremlin_python.driver import client, serializer

            if sys.platform == "win32":
                asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        except ImportError as e:
            raise ImportError(
                "Please install gremlin-python first: `pip3 install gremlinpython`"
            ) from e

        self.client = client.Client(
            url=get_from_env("url", "GREMLIN_URI", url),
            traversal_source=traversal_source,
            username=get_from_env("username", "GREMLIN_USERNAME", username),
            password=get_from_env("password", "GREMLIN_PASSWORD", password),
            message_serializer=message_serializer
            or serializer.GraphSONSerializersV2d0(),
        )
        self.schema: str = ""
        # Initialised here so the property never hits a missing attribute.
        self.structured_schema: Dict[str, Any] = {}
        self.include_edge_properties = include_edge_properties

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Gremlin database, loading it on first access."""
        if len(self.schema) == 0:
            self.refresh_schema()
        return self.schema

    def refresh_schema(self) -> None:
        """Refreshes the Gremlin graph schema information.

        Populates ``structured_schema`` and the text ``schema`` with vertex
        labels, edge labels and per-label vertex properties, plus per-label
        edge properties when ``include_edge_properties`` is set.
        """
        vertex_schema = self.client.submit("g.V().label().dedup()").all().result()
        edge_schema = self.client.submit("g.E().label().dedup()").all().result()
        vertex_properties = (
            self.client.submit(
                "g.V().group().by(label).by(properties().label().dedup().fold())"
            )
            .all()
            .result()[0]
        )

        self.structured_schema = {
            "vertex_labels": vertex_schema,
            "edge_labels": edge_schema,
            "vertice_props": vertex_properties,
        }

        self.schema = "\n".join(
            [
                "Vertex labels are the following:",
                ",".join(vertex_schema),
                "Edge labels are the following:",
                ",".join(edge_schema),
                f"Vertices have following properties:\n{vertex_properties}",
            ]
        )
        if self.include_edge_properties:
            edge_properties = (
                self.client.submit(
                    "g.E().group().by(label)"
                    ".by(project('inVLabel', 'outVLabel','properties')"
                    ".by(inV().label()).by(outV().label()).by(properties().key().dedup()"
                    ".fold()).dedup().fold())"
                )
                .all()
                .result()[0]
            )
            self.structured_schema["edge_props"] = edge_properties
            self.schema += (
                f"\nEdges have the following properties, grouped by label and"
                f" the distinct inV and outV labels:\n {edge_properties}"
            )

    def query(
        self, query: str, params: Optional[dict] = None
    ) -> List[Dict[str, Any]]:
        """Submit a Gremlin query and return all of its results.

        Args:
            query: Gremlin script to submit.
            params: Accepted for interface compatibility; this implementation
                does not forward it to the server.
        """
        q = self.client.submit(query)
        return q.all().result()

    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """Take GraphDocument as input and use it to construct a graph.

        Args:
            graph_documents: Documents whose nodes and relationships are written.
            include_source: When True, a "Document" vertex is created per
                source document and linked to each of its nodes in both
                directions.
        """
        # One cache per call. Every add_node below receives it, so a vertex
        # is submitted at most once per call and nothing leaks between calls.
        node_cache: Dict[Union[str, int], Node] = {}
        for document in graph_documents:
            if include_source:
                # Create document vertex
                doc_props = {
                    "page_content": document.source.page_content,
                    "metadata": document.source.metadata,
                }
                # md5 only derives a stable id from the content here.
                doc_id = hashlib.md5(document.source.page_content.encode()).hexdigest()
                doc_node = self.add_node(
                    Node(id=doc_id, type="Document", properties=doc_props), node_cache
                )

            # Import nodes to vertices
            for n in document.nodes:
                node = self.add_node(n, node_cache)
                if include_source:
                    # Add Edge to document for each node
                    self.add_edge(
                        Relationship(
                            type="contains information about",
                            source=doc_node,
                            target=node,
                            properties={},
                        )
                    )
                    self.add_edge(
                        Relationship(
                            type="is extracted from",
                            source=node,
                            target=doc_node,
                            properties={},
                        )
                    )

            # Edges
            for el in document.relationships:
                # Find or create the source vertex
                self.add_node(el.source, node_cache)
                # Find or create the target vertex
                self.add_node(el.target, node_cache)
                # Find or create the edge
                self.add_edge(el)

    @staticmethod
    def _escape_literal(value: Any) -> str:
        """Render ``value`` as text safe inside a single-quoted Gremlin string."""
        # Backslashes first, so the escapes added for quotes are not doubled.
        return str(value).replace("\\", "\\\\").replace("'", "\\'")

    def build_vertex_query(self, node: Node) -> str:
        """Build a find-or-create query for ``node``.

        The vertex is looked up by ``id``; if absent it is added with its
        ``id``, ``type`` and every entry of ``node.properties``.
        """
        esc = self._escape_literal
        node_id = esc(node.id)
        node_type = esc(node.type)
        base_query = (
            f"g.V().has('id','{node_id}').fold()"
            + f".coalesce(unfold(),addV('{node_type}')"
            + f".property('id','{node_id}')"
            + f".property('type','{node_type}')"
        )
        for key, value in node.properties.items():
            base_query += f".property('{esc(key)}', '{esc(value)}')"

        return base_query + ")"

    def build_edge_query(self, relationship: Relationship) -> str:
        """Build a find-or-create query for ``relationship``.

        An existing edge of the same type from source to target is reused;
        otherwise one is added. Properties are then set on the edge.
        """
        esc = self._escape_literal
        source_query = f".has('id','{esc(relationship.source.id)}')"
        target_query = f".has('id','{esc(relationship.target.id)}')"
        rel_type = esc(relationship.type)

        # Assembled from fragments: the former triple-quoted literal began
        # with a stray double quote that ended up in the submitted script.
        base_query = (
            f"g.V(){source_query}.as('a')"
            f".V(){target_query}.as('b')"
            ".choose("
            f"__.inE('{rel_type}').where(outV().as('a')),"
            "__.identity(),"
            f"__.addE('{rel_type}').from('a').to('b')"
            ")"
        )
        for key, value in relationship.properties.items():
            base_query += f".property('{esc(key)}', '{esc(value)}')"

        return base_query

    def add_node(self, node: Node, node_cache: Optional[dict] = None) -> Node:
        """Create the vertex for ``node`` unless it is already in ``node_cache``.

        Args:
            node: Node to write. Its ``properties`` gain a ``"label"`` entry
                (set to ``node.type``) when none is present.
            node_cache: Mapping of node id to Node already written. It is
                updated in place. ``None`` means no caching across calls.

        Returns:
            The cached node for this id if there is one, otherwise ``node``.
        """
        if node_cache is None:
            node_cache = {}
        # if properties does not have label, add type as label
        if "label" not in node.properties:
            node.properties["label"] = node.type
        if node.id in node_cache:
            return node_cache[node.id]

        query = self.build_vertex_query(node)
        # Indexing [0] keeps the original expectation that the
        # find-or-create traversal yields a vertex.
        _ = self.client.submit(query).all().result()[0]
        node_cache[node.id] = node
        return node

    def add_edge(self, relationship: Relationship) -> Any:
        """Submit the find-or-create query for ``relationship`` and return its result."""
        query = self.build_edge_query(relationship)
        return self.client.submit(query).all().result()
|
||||
@@ -0,0 +1,74 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class HugeGraph:
    """HugeGraph wrapper for graph operations.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        username: str = "default",
        password: str = "default",
        address: str = "127.0.0.1",
        port: int = 8081,
        graph: str = "hugegraph",
    ) -> None:
        """Create a new HugeGraph wrapper instance and load its schema.

        Args:
            username: User name passed to the client as ``user``.
            password: Password passed to the client as ``pwd``.
            address: Server address.
            port: Server port.
            graph: Name of the graph to open.

        Raises:
            ImportError: If ``hugegraph.connection`` cannot be imported.
            ValueError: If the initial schema refresh fails; the original
                exception is attached as the cause.
        """
        try:
            from hugegraph.connection import PyHugeGraph
        except ImportError as e:
            raise ImportError(
                "Please install HugeGraph Python client first: "
                "`pip3 install hugegraph-python`"
            ) from e

        self.username = username
        self.password = password
        self.address = address
        self.port = port
        self.graph = graph
        self.client = PyHugeGraph(
            address, port, user=username, pwd=password, graph=graph
        )
        self.schema = ""
        # Set schema
        try:
            self.refresh_schema()
        except Exception as e:
            # Chain the cause so the underlying client error is not lost.
            raise ValueError(f"Could not refresh schema. Error: {e}") from e

    @property
    def get_schema(self) -> str:
        """Returns the schema of the HugeGraph database"""
        return self.schema

    def refresh_schema(self) -> None:
        """Refreshes the HugeGraph schema information.

        Reads vertex labels, edge labels and relations from the client's
        schema object and stores them as text in ``self.schema``.
        """
        schema = self.client.schema()
        vertex_schema = schema.getVertexLabels()
        edge_schema = schema.getEdgeLabels()
        relationships = schema.getRelations()

        self.schema = (
            f"Node properties: {vertex_schema}\n"
            f"Edge properties: {edge_schema}\n"
            f"Relationships: {relationships}\n"
        )

    def query(self, query: str) -> List[Dict[str, Any]]:
        """Execute a Gremlin query and return the ``"data"`` entry of the response."""
        g = self.client.gremlin()
        res = g.exec(query)
        return res["data"]
|
||||
@@ -0,0 +1,99 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
from langchain_community.graphs import NetworkxEntityGraph
|
||||
from langchain_community.graphs.networkx_graph import KG_TRIPLE_DELIMITER
|
||||
from langchain_community.graphs.networkx_graph import parse_triples
|
||||
|
||||
# flake8: noqa
|
||||
|
||||
# Few-shot prompt asking the model to extract (subject, predicate, object)
# triples from a piece of text. Triples in the example outputs are joined with
# KG_TRIPLE_DELIMITER, and "NONE" marks text with no triples. The only
# template variable is {text}.
# NOTE(review): unlike the worked examples, "{text}" is followed directly by
# "Output:" with no newline in between — confirm this is intended.
_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = (
    "You are a networked intelligence helping a human track knowledge triples"
    " about all relevant people, things, concepts, etc. and integrating"
    " them with your knowledge stored within your weights"
    " as well as that stored in a knowledge graph."
    " Extract all of the knowledge triples from the text."
    " A knowledge triple is a clause that contains a subject, a predicate,"
    " and an object. The subject is the entity being described,"
    " the predicate is the property of the subject that is being"
    " described, and the object is the value of the property.\n\n"
    "EXAMPLE\n"
    "It's a state in the US. It's also the number 1 producer of gold in the US.\n\n"
    f"Output: (Nevada, is a, state){KG_TRIPLE_DELIMITER}(Nevada, is in, US)"
    f"{KG_TRIPLE_DELIMITER}(Nevada, is the number 1 producer of, gold)\n"
    "END OF EXAMPLE\n\n"
    "EXAMPLE\n"
    "I'm going to the store.\n\n"
    "Output: NONE\n"
    "END OF EXAMPLE\n\n"
    "EXAMPLE\n"
    "Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n"
    f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n"
    "END OF EXAMPLE\n\n"
    "EXAMPLE\n"
    "{text}"
    "Output:"
)

# Default prompt used by GraphIndexCreator.from_text / afrom_text.
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT = PromptTemplate(
    input_variables=["text"],
    template=_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE,
)
|
||||
|
||||
|
||||
class GraphIndexCreator(BaseModel):
    """Functionality to create graph index.

    Runs an LLM over text with a triple-extraction prompt, parses the output
    with ``parse_triples`` and adds each triple to a new graph.
    """

    # Model used for extraction; must be set before calling from_text/afrom_text.
    llm: Optional[BaseLanguageModel] = None
    # Graph class instantiated (with no arguments) for every call.
    graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph

    def _prepare(self, prompt: BasePromptTemplate):  # type: ignore[no-untyped-def]
        """Validate state and return ``(empty graph, LLM chain)`` for one run."""
        if self.llm is None:
            raise ValueError("llm should not be None")
        graph = self.graph_type()
        # Temporary local scoped import while community does not depend on
        # langchain explicitly
        try:
            from langchain_classic.chains import LLMChain
        except ImportError as e:
            # The module imported above is langchain_classic, so point the
            # user at the package that provides it.
            raise ImportError(
                "Please install langchain-classic to use this functionality. "
                "You can install it with `pip install langchain-classic`."
            ) from e
        return graph, LLMChain(llm=self.llm, prompt=prompt)

    @staticmethod
    def _add_triples(graph: NetworkxEntityGraph, output: str) -> NetworkxEntityGraph:
        """Parse ``output`` into triples, add them to ``graph`` and return it."""
        for triple in parse_triples(output):
            graph.add_triple(triple)
        return graph

    def from_text(
        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
    ) -> NetworkxEntityGraph:
        """Create graph index from text.

        Args:
            text: Text to extract knowledge triples from.
            prompt: Prompt template; it is invoked with a ``text`` variable.

        Returns:
            A new ``graph_type`` instance holding the extracted triples.

        Raises:
            ValueError: If ``llm`` is None.
            ImportError: If ``langchain_classic`` cannot be imported.
        """
        graph, chain = self._prepare(prompt)
        return self._add_triples(graph, chain.predict(text=text))

    async def afrom_text(
        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
    ) -> NetworkxEntityGraph:
        """Create graph index from text asynchronously.

        Same contract as ``from_text``, but awaits the chain's ``apredict``.
        """
        graph, chain = self._prepare(prompt)
        return self._add_triples(graph, await chain.apredict(text=text))
|
||||
264
venv/Lib/site-packages/langchain_community/graphs/kuzu_graph.py
Normal file
264
venv/Lib/site-packages/langchain_community/graphs/kuzu_graph.py
Normal file
@@ -0,0 +1,264 @@
|
||||
from hashlib import md5
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument, Relationship
|
||||
|
||||
|
||||
class KuzuGraph:
|
||||
"""Kùzu wrapper for graph operations.
|
||||
|
||||
*Security note*: Make sure that the database connection uses credentials
|
||||
that are narrowly-scoped to only include necessary permissions.
|
||||
Failure to do so may result in data corruption or loss, since the calling
|
||||
code may attempt commands that would result in deletion, mutation
|
||||
of data if appropriately prompted or reading sensitive data if such
|
||||
data is present in the database.
|
||||
The best way to guard against such negative outcomes is to (as appropriate)
|
||||
limit the permissions granted to the credentials used with this tool.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, db: Any, database: str = "kuzu", allow_dangerous_requests: bool = False
|
||||
) -> None:
|
||||
"""Initializes the Kùzu graph database connection."""
|
||||
|
||||
if allow_dangerous_requests is not True:
|
||||
raise ValueError(
|
||||
"The KuzuGraph class is a powerful tool that can be used to execute "
|
||||
"arbitrary queries on the database. To enable this functionality, "
|
||||
"set the `allow_dangerous_requests` parameter to `True` when "
|
||||
"constructing the KuzuGraph object."
|
||||
)
|
||||
|
||||
try:
|
||||
import kuzu
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import Kùzu python package."
|
||||
"Please install Kùzu with `pip install kuzu`."
|
||||
)
|
||||
self.db = db
|
||||
self.conn = kuzu.Connection(self.db)
|
||||
self.database = database
|
||||
self.refresh_schema()
|
||||
|
||||
@property
def get_schema(self) -> str:
    """Returns the schema of the Kùzu database.

    This is the text last stored in ``self.schema`` by ``refresh_schema``;
    reading the property does not query the database.
    """
    return self.schema
|
||||
|
||||
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
|
||||
"""Query Kùzu database"""
|
||||
result = self.conn.execute(query, params)
|
||||
column_names = result.get_column_names()
|
||||
return_list = []
|
||||
while result.has_next():
|
||||
row = result.get_next()
|
||||
return_list.append(dict(zip(column_names, row)))
|
||||
return return_list
|
||||
|
||||
def refresh_schema(self) -> None:
|
||||
"""Refreshes the Kùzu graph schema information"""
|
||||
node_properties = []
|
||||
node_table_names = self.conn._get_node_table_names()
|
||||
for table_name in node_table_names:
|
||||
current_table_schema = {"properties": [], "label": table_name}
|
||||
properties = self.conn._get_node_property_names(table_name)
|
||||
for property_name in properties:
|
||||
property_type = properties[property_name]["type"]
|
||||
list_type_flag = ""
|
||||
if properties[property_name]["dimension"] > 0:
|
||||
if "shape" in properties[property_name]:
|
||||
for s in properties[property_name]["shape"]:
|
||||
list_type_flag += f"[{s}]"
|
||||
else:
|
||||
for i in range(properties[property_name]["dimension"]):
|
||||
list_type_flag += "[]"
|
||||
property_type += list_type_flag
|
||||
current_table_schema["properties"].append(
|
||||
(
|
||||
property_name,
|
||||
property_type,
|
||||
)
|
||||
)
|
||||
node_properties.append(current_table_schema)
|
||||
|
||||
relationships = []
|
||||
rel_tables = self.conn._get_rel_table_names()
|
||||
for table in rel_tables:
|
||||
relationships.append(
|
||||
f"(:{table['src']})-[:{table['name']}]->(:{table['dst']})"
|
||||
)
|
||||
|
||||
rel_properties = []
|
||||
for table in rel_tables:
|
||||
table_name = table["name"]
|
||||
current_table_schema = {"properties": [], "label": table_name}
|
||||
query_result = self.conn.execute(
|
||||
f"CALL table_info('{table_name}') RETURN *;"
|
||||
)
|
||||
while query_result.has_next():
|
||||
row = query_result.get_next()
|
||||
prop_name = row[1]
|
||||
prop_type = row[2]
|
||||
current_table_schema["properties"].append((prop_name, prop_type))
|
||||
rel_properties.append(current_table_schema)
|
||||
|
||||
self.schema = (
|
||||
f"Node properties: {node_properties}\n"
|
||||
f"Relationships properties: {rel_properties}\n"
|
||||
f"Relationships: {relationships}\n"
|
||||
)
|
||||
|
||||
def _create_chunk_node_table(self) -> None:
|
||||
self.conn.execute(
|
||||
"""
|
||||
CREATE NODE TABLE IF NOT EXISTS Chunk (
|
||||
id STRING,
|
||||
text STRING,
|
||||
type STRING,
|
||||
PRIMARY KEY(id)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_entity_node_table(self, node_label: str) -> None:
|
||||
self.conn.execute(
|
||||
f"""
|
||||
CREATE NODE TABLE IF NOT EXISTS {node_label} (
|
||||
id STRING,
|
||||
type STRING,
|
||||
PRIMARY KEY(id)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_entity_relationship_table(self, rel: Relationship) -> None:
|
||||
self.conn.execute(
|
||||
f"""
|
||||
CREATE REL TABLE IF NOT EXISTS {rel.type} (
|
||||
FROM {rel.source.type} TO {rel.target.type}
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def add_graph_documents(
    self,
    graph_documents: List[GraphDocument],
    allowed_relationships: List[Tuple[str, str, str]],
    include_source: bool = False,
) -> None:
    """
    Adds a list of `GraphDocument` objects that represent nodes and relationships
    in a graph to a Kùzu backend.

    Parameters:
    - graph_documents (List[GraphDocument]): A list of `GraphDocument` objects
      that contain the nodes and relationships to be added to the graph. Each
      `GraphDocument` should encapsulate the structure of part of the graph,
      including nodes, relationships, and the source document information.

    - allowed_relationships (List[Tuple[str, str, str]]): A list of
      (source node type, relationship type, target node type) tuples.
      Kept for interface compatibility: this method does not read it.
      Relationship tables are created from each relationship's own `type`,
      `source.type` and `target.type`.

    - include_source (bool): If True, stores the source document
      and links it to nodes in the graph using the `MENTIONS` relationship.
      This is useful for tracing back the origin of data. Merges source
      documents based on the `id` property from the source document metadata
      if available; otherwise it calculates the MD5 hash of `page_content`
      for merging process. Defaults to False.
    """
    # Unique node labels across all documents; sorted so the generated DDL is
    # the same from run to run.
    node_labels = sorted(
        {node.type for document in graph_documents for node in document.nodes}
    )

    # All schema statements use IF NOT EXISTS and do not depend on the
    # individual document or node, so they are issued once up front instead
    # of once per document/node.
    if include_source and graph_documents:
        self._create_chunk_node_table()
    for node_label in node_labels:
        self._create_entity_node_table(node_label)
    if include_source and node_labels:
        # One MENTIONS group spanning Chunk -> every entity table, with the
        # properties shared by all of its member tables.
        from_to = ", ".join(f"FROM Chunk TO {label}" for label in node_labels)
        self.conn.execute(
            "CREATE REL TABLE GROUP IF NOT EXISTS MENTIONS ("
            f"{from_to}, label STRING, triplet_source_id STRING)"
        )

    for document in graph_documents:
        if include_source:
            source = document.source
            if not source.metadata.get("id"):
                # No id supplied: derive a stable one from the chunk text.
                source.metadata["id"] = md5(
                    source.page_content.encode("utf-8")
                ).hexdigest()

            self.conn.execute(
                """
                MERGE (c:Chunk {id: $id})
                    SET c.text = $text,
                        c.type = "text_chunk"
                """,
                parameters={
                    "id": source.metadata["id"],
                    "text": source.page_content,
                },
            )

        # Add entity nodes, linking each to its source chunk when requested.
        for node in document.nodes:
            self.conn.execute(
                f"""
                MERGE (e:{node.type} {{id: $id}})
                    SET e.type = "entity"
                """,
                parameters={"id": node.id},
            )
            if include_source:
                self.conn.execute(
                    f"""
                    MATCH (c:Chunk {{id: $id}}),
                            (e:{node.type} {{id: $node_id}})
                    MERGE (c)-[m:MENTIONS]->(e)
                    SET m.triplet_source_id = $id
                    """,
                    parameters={
                        "id": document.source.metadata["id"],
                        "node_id": node.id,
                    },
                )

        # Add entity relationships, creating each relationship table on demand.
        for rel in document.relationships:
            self._create_entity_relationship_table(rel)
            self.conn.execute(
                f"""
                MATCH (e1:{rel.source.type} {{id: $source_id}}),
                        (e2:{rel.target.type} {{id: $target_id}})
                MERGE (e1)-[:{rel.type}]->(e2)
                """,
                parameters={
                    "source_id": rel.source.id,
                    "target_id": rel.target.id,
                },
            )
|
||||
@@ -0,0 +1,525 @@
|
||||
import logging
|
||||
from hashlib import md5
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||
from langchain_community.graphs.graph_store import GraphStore
|
||||
|
||||
logger = logging.getLogger(__name__)


# Secondary label added to imported nodes when `baseEntityLabel` is requested.
BASE_ENTITY_LABEL = "__Entity__"

# Primary schema query; MemgraphGraph.refresh_schema tries this first.
SCHEMA_QUERY = """
SHOW SCHEMA INFO
"""

# Fallback schema query: one `output` map of labels and properties per node type.
NODE_PROPERTIES_QUERY = """
CALL schema.node_type_properties()
YIELD nodeType AS label, propertyName AS property, propertyTypes AS type
WITH label AS nodeLabels, collect({key: property, types: type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output
"""

# Fallback schema query: distinct (start labels, type, end labels) combinations
# with the [name, value type] pairs of the relationship properties.
REL_QUERY = """
MATCH (n)-[e]->(m)
WITH DISTINCT
    labels(n) AS start_node_labels,
    type(e) AS rel_type,
    labels(m) AS end_node_labels,
    e,
    keys(e) AS properties
UNWIND CASE WHEN size(properties) > 0 THEN properties ELSE [null] END AS prop
WITH
    start_node_labels,
    rel_type,
    end_node_labels,
    CASE WHEN prop IS NULL THEN [] ELSE [prop, valueType(e[prop])] END AS property_info
RETURN
    start_node_labels,
    rel_type,
    end_node_labels,
    COLLECT(DISTINCT CASE
        WHEN property_info <> []
        THEN property_info
        ELSE null END) AS properties_info
"""

# Merges one node per `$data` row from its label list and property map.
NODE_IMPORT_QUERY = """
UNWIND $data AS row
CALL merge.node(row.label, row.properties, {}, {})
YIELD node
RETURN distinct 'done' AS result
"""

# Ensures both endpoints of every `$data` relationship row exist (matched by id).
REL_NODES_IMPORT_QUERY = """
UNWIND $data AS row
MERGE (source {id: row.source_id})
MERGE (target {id: row.target_id})
RETURN distinct 'done' AS result
"""

# Merges a relationship of `row.type` between the endpoints of each `$data` row.
REL_IMPORT_QUERY = """
UNWIND $data AS row
MATCH (source {id: row.source_id})
MATCH (target {id: row.target_id})
WITH source, target, row
CALL merge.relationship(source, row.type, {}, {}, target, {})
YIELD rel
RETURN distinct 'done' AS result
"""

# Merges the source document as a `Document` node keyed by its metadata id.
INCLUDE_DOCS_QUERY = """
MERGE (d:Document {id:$document.metadata.id})
SET d.content = $document.page_content
SET d += $document.metadata
RETURN distinct 'done' AS result
"""

# Links the `Document` node to the source node of each `$data` row via MENTIONS.
INCLUDE_DOCS_SOURCE_QUERY = """
UNWIND $data AS row
MATCH (source {id: row.source_id}), (d:Document {id: $document.metadata.id})
MERGE (d)-[:MENTIONS]->(source)
RETURN distinct 'done' AS result
"""

# Section headers used by transform_schema_to_text.
NODE_PROPS_TEXT = """
Node labels and properties (name and type) are:
"""

REL_PROPS_TEXT = """
Relationship labels and properties are:
"""

REL_TEXT = """
Nodes are connected with the following relationships:
"""
|
||||
|
||||
|
||||
def get_schema_subset(data: Dict[str, Any]) -> Dict[str, Any]:
    """Project a raw schema dict onto the fields used for the structured schema.

    Only the label/type fields and each property's ``key`` and lower-cased
    ``type`` names are kept; every other key in ``data`` is dropped.

    Args:
        data: Mapping with ``"edges"`` and ``"nodes"`` lists.

    Returns:
        A new dict with ``"edges"`` and ``"nodes"`` lists in the reduced form.
    """

    def _reduced(properties: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Keep the key and the lower-cased type names of every property.
        return [
            {
                "key": entry["key"],
                "types": [{"type": kind["type"].lower()} for kind in entry["types"]],
            }
            for entry in properties
        ]

    edges = []
    for edge in data["edges"]:
        edges.append(
            {
                "end_node_labels": edge["end_node_labels"],
                "properties": _reduced(edge["properties"]),
                "start_node_labels": edge["start_node_labels"],
                "type": edge["type"],
            }
        )

    nodes = []
    for node in data["nodes"]:
        nodes.append(
            {
                "labels": node["labels"],
                "properties": _reduced(node["properties"]),
            }
        )

    return {"edges": edges, "nodes": nodes}
|
||||
|
||||
|
||||
def get_reformated_schema(
    nodes: List[Dict[str, Any]], rels: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Build the structured schema from the fallback query results.

    Args:
        nodes: Rows with a ``"labels"`` string and a ``"properties"`` list of
            ``{"key", "types"}`` dicts.
        rels: Rows with ``start_node_labels``, ``rel_type``,
            ``end_node_labels`` and ``properties_info`` (``[name, type]`` pairs).

    Returns:
        A dict with ``"edges"`` and ``"nodes"`` lists.
    """
    edges = []
    for rel in rels:
        end_labels = rel["end_node_labels"]
        edge_properties = [
            {"key": pair[0], "types": [{"type": pair[1].lower()}]}
            for pair in rel["properties_info"]
        ]
        edges.append(
            {
                "end_node_labels": end_labels,
                "properties": edge_properties,
                "start_node_labels": rel["start_node_labels"],
                "type": rel["rel_type"],
            }
        )

    reformatted_nodes = []
    for node in nodes:
        # Strip backticks, then drop the first character of the label string.
        label = _remove_backticks(node["labels"])[1:]
        raw_properties = node["properties"]
        # If the first property has an empty key, no properties are emitted.
        if raw_properties and raw_properties[0]["key"] != "":
            node_properties = [
                {
                    "key": entry["key"],
                    "types": [{"type": kind.lower()} for kind in entry["types"]],
                }
                for entry in raw_properties
            ]
        else:
            node_properties = []
        reformatted_nodes.append({"labels": [label], "properties": node_properties})

    return {"edges": edges, "nodes": reformatted_nodes}
|
||||
|
||||
|
||||
def transform_schema_to_text(schema: Dict[str, Any]) -> str:
    """Render a structured schema as the text description stored in ``schema``.

    Up to three sections are produced (node properties, relationship
    properties, relationship patterns); a section with no content is omitted
    together with its header.

    Args:
        schema: Dict with ``"nodes"`` and ``"edges"`` lists.

    Returns:
        The concatenated sections, or ``""`` when all are empty.
    """
    node_lines = []
    rel_prop_lines = []
    rel_lines = []

    for node in schema["nodes"]:
        node_lines.append(f"- labels: (:{':'.join(node['labels'])})\n")
        if node["properties"] == []:
            continue
        node_lines.append("  properties:\n")
        for prop in node["properties"]:
            # De-duplicate type names; sorted because set iteration order is
            # not stable between interpreter runs, which would make the
            # schema text vary for multi-typed properties.
            prop_types_str = " or ".join(
                sorted({prop_types["type"] for prop_types in prop["types"]})
            )
            node_lines.append(f"    - {prop['key']}: {prop_types_str}\n")

    for rel in schema["edges"]:
        rel_type = rel["type"]
        start_labels = ":".join(rel["start_node_labels"])
        end_labels = ":".join(rel["end_node_labels"])
        rel_lines.append(f"(:{start_labels})-[:{rel_type}]->(:{end_labels})\n")

        if rel["properties"] == []:
            continue

        rel_prop_lines.append(f"- labels: {rel_type}\n  properties:\n")
        for prop in rel["properties"]:
            prop_types_str = " or ".join(
                sorted({prop_types["type"].lower() for prop_types in prop["types"]})
            )
            rel_prop_lines.append(f"    - {prop['key']}: {prop_types_str}\n")

    node_props_data = "".join(node_lines)
    rel_props_data = "".join(rel_prop_lines)
    rel_data = "".join(rel_lines)

    return "".join(
        [
            NODE_PROPS_TEXT + node_props_data if node_props_data else "",
            REL_PROPS_TEXT + rel_props_data if rel_props_data else "",
            REL_TEXT + rel_data if rel_data else "",
        ]
    )
|
||||
|
||||
|
||||
def _remove_backticks(text: str) -> str:
|
||||
return text.replace("`", "")
|
||||
|
||||
|
||||
def _transform_nodes(nodes: list[Node], baseEntityLabel: bool) -> List[dict]:
    """Convert nodes into the row dicts consumed by the node import query.

    Args:
        nodes: Nodes to convert.
        baseEntityLabel: When True, ``BASE_ENTITY_LABEL`` is added after the
            node's own label.

    Returns:
        One ``{"label": [...], "properties": {...}}`` dict per node; the
        properties are the node's own plus its ``id``.
    """
    extra_labels = [BASE_ENTITY_LABEL] if baseEntityLabel else []
    rows = []
    for node in nodes:
        properties = node.properties | {"id": node.id}
        labels = [_remove_backticks(node.type), *extra_labels]
        rows.append({"label": labels, "properties": properties})
    return rows
|
||||
|
||||
|
||||
def _transform_relationships(
    relationships: list[Relationship], baseEntityLabel: bool
) -> List[dict]:
    """Convert relationships into the row dicts consumed by the import queries.

    Args:
        relationships: Relationships to convert.
        baseEntityLabel: When True, both endpoint label lists are
            ``[BASE_ENTITY_LABEL]`` instead of the endpoint's own type.

    Returns:
        One dict per relationship with ``type``, ``source_label``,
        ``source_id``, ``target_label`` and ``target_id``.
    """

    def _endpoint_labels(endpoint: Node) -> List[str]:
        if baseEntityLabel:
            return [BASE_ENTITY_LABEL]
        return [_remove_backticks(endpoint.type)]

    return [
        {
            "type": _remove_backticks(rel.type),
            "source_label": _endpoint_labels(rel.source),
            "source_id": rel.source.id,
            "target_label": _endpoint_labels(rel.target),
            "target_id": rel.target.id,
        }
        for rel in relationships
    ]
|
||||
|
||||
|
||||
class MemgraphGraph(GraphStore):
    """Memgraph wrapper for graph operations.

    Parameters:
    url (Optional[str]): The URL of the Memgraph database server.
    username (Optional[str]): The username for database authentication.
    password (Optional[str]): The password for database authentication.
    database (str): The name of the database to connect to. Default is 'memgraph'.
    refresh_schema (bool): A flag whether to refresh schema information
    at initialization. Default is True.
    driver_config (Dict): Configuration passed to Neo4j Driver.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        database: Optional[str] = None,
        refresh_schema: bool = True,
        *,
        driver_config: Optional[Dict] = None,
    ) -> None:
        """Create a new Memgraph graph wrapper instance.

        Arguments left as None are read from the ``MEMGRAPH_URI``,
        ``MEMGRAPH_USERNAME``, ``MEMGRAPH_PASSWORD`` and ``MEMGRAPH_DATABASE``
        environment variables (the database falls back to ``"memgraph"``).

        Raises:
            ImportError: If the ``neo4j`` package is not installed.
            ValueError: If the server is unreachable or rejects the credentials.
        """
        try:
            import neo4j
        except ImportError as e:
            raise ImportError(
                "Could not import neo4j python package. "
                "Please install it with `pip install neo4j`."
            ) from e

        url = get_from_dict_or_env({"url": url}, "url", "MEMGRAPH_URI")

        # if username and password are "", assume auth is disabled
        if username == "" and password == "":
            auth = None
        else:
            username = get_from_dict_or_env(
                {"username": username},
                "username",
                "MEMGRAPH_USERNAME",
            )
            password = get_from_dict_or_env(
                {"password": password},
                "password",
                "MEMGRAPH_PASSWORD",
            )
            auth = (username, password)
        database = get_from_dict_or_env(
            {"database": database}, "database", "MEMGRAPH_DATABASE", "memgraph"
        )

        self._driver = neo4j.GraphDatabase.driver(
            url, auth=auth, **(driver_config or {})
        )

        self._database = database
        self.schema: str = ""
        self.structured_schema: Dict[str, Any] = {}

        # Verify connection; release the driver before reporting a failure so
        # an unusable instance does not keep it open.
        try:
            self._driver.verify_connectivity()
        except neo4j.exceptions.ServiceUnavailable as e:
            self._driver.close()
            raise ValueError(
                "Could not connect to Memgraph database. "
                "Please ensure that the url is correct"
            ) from e
        except neo4j.exceptions.AuthError as e:
            self._driver.close()
            raise ValueError(
                "Could not connect to Memgraph database. "
                "Please ensure that the username and password are correct"
            ) from e

        # Set schema; errors from the refresh propagate to the caller.
        if refresh_schema:
            self.refresh_schema()

    def close(self) -> None:
        """Close the driver connection, if one is open, and drop the reference."""
        if self._driver:
            logger.info("Closing the driver connection.")
            self._driver.close()
            self._driver = None

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Graph database"""
        return self.schema

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        """Returns the structured schema of the Graph database"""
        return self.structured_schema

    def query(self, query: str, params: Optional[dict] = None) -> List[Dict[str, Any]]:
        """Query the graph.

        Args:
            query (str): The Cypher query to execute.
            params (dict): The parameters to pass to the query. Defaults to
                no parameters.

        Returns:
            List[Dict[str, Any]]: The list of dictionaries containing the query results.

        Raises:
            Neo4jError: For any driver error other than the transaction-mode
                errors that trigger the session fallback below.
        """
        from neo4j.exceptions import Neo4jError

        # A fresh dict per call instead of a shared mutable default.
        params = {} if params is None else params

        try:
            data, _, _ = self._driver.execute_query(
                query,
                database_=self._database,
                parameters_=params,
            )
            return [r.data() for r in data]
        except Neo4jError as e:
            # Only errors saying the statement cannot run in this kind of
            # transaction are retried through a plain session; all others
            # are re-raised unchanged.
            if not (
                (
                    (  # isCallInTransactionError
                        e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
                        or e.code
                        == "Neo.DatabaseError.Transaction.TransactionStartFailed"
                    )
                    and "in an implicit transaction" in e.message
                )
                or (  # isPeriodicCommitError
                    e.code == "Neo.ClientError.Statement.SemanticError"
                    and (
                        "in an open transaction is not possible" in e.message
                        or "tried to execute in an explicit transaction" in e.message
                    )
                )
                or (
                    e.code == "Memgraph.ClientError.MemgraphError.MemgraphError"
                    and ("in multicommand transactions" in e.message)
                )
                or (
                    e.code == "Memgraph.ClientError.MemgraphError.MemgraphError"
                    and "SchemaInfo disabled" in e.message
                )
            ):
                raise

        # fallback to allow implicit transactions
        with self._driver.session(database=self._database) as session:
            data = session.run(query, params)
            return [r.data() for r in data]

    def refresh_schema(self) -> None:
        """
        Refreshes the Memgraph graph schema information.

        Tries ``SHOW SCHEMA INFO`` first and, if that query raises a driver
        error, falls back to the node/relationship introspection queries.
        The schema is left untouched when the database has no nodes.
        """
        import ast

        from neo4j.exceptions import Neo4jError

        # leave schema empty if db is empty
        if self.query("MATCH (n) RETURN n LIMIT 1") == []:
            return

        # first try with SHOW SCHEMA INFO
        try:
            result = self.query(SCHEMA_QUERY)[0].get("schema")
            if result is not None and isinstance(result, (str, ast.AST)):
                # The schema may arrive as a literal in text form.
                schema_result = ast.literal_eval(result)
            else:
                schema_result = result
            assert schema_result is not None
            structured_schema = get_schema_subset(schema_result)
            self.structured_schema = structured_schema
            self.schema = transform_schema_to_text(structured_schema)
            return
        except Neo4jError as e:
            if (
                e.code == "Memgraph.ClientError.MemgraphError.MemgraphError"
                and "SchemaInfo disabled" in e.message
            ):
                logger.info(
                    "Schema generation with SHOW SCHEMA INFO query failed. "
                    "Set --schema-info-enabled=true to use SHOW SCHEMA INFO query. "
                    "Falling back to alternative queries."
                )
            else:
                # Still fall back, but leave a trace of why the primary
                # query was abandoned.
                logger.debug(
                    "SHOW SCHEMA INFO query failed (%s). "
                    "Falling back to alternative queries.",
                    e,
                )

        # fallback on Cypher without SHOW SCHEMA INFO
        nodes = [query["output"] for query in self.query(NODE_PROPERTIES_QUERY)]
        rels = self.query(REL_QUERY)

        structured_schema = get_reformated_schema(nodes, rels)
        self.structured_schema = structured_schema
        self.schema = transform_schema_to_text(structured_schema)

    def add_graph_documents(
        self,
        graph_documents: List[GraphDocument],
        include_source: bool = False,
        baseEntityLabel: bool = False,
    ) -> None:
        """
        Take GraphDocument as input and use it to construct a graph in Memgraph.

        Parameters:
        - graph_documents (List[GraphDocument]): A list of GraphDocument objects
        that contain the nodes and relationships to be added to the graph. Each
        GraphDocument should encapsulate the structure of part of the graph,
        including nodes, relationships, and the source document information.
        - include_source (bool, optional): If True, stores the source document
        and links it to nodes in the graph using the MENTIONS relationship.
        This is useful for tracing back the origin of data. Merges source
        documents based on the `id` property from the source document metadata
        if available; otherwise it calculates the MD5 hash of `page_content`
        for merging process. Defaults to False.
        - baseEntityLabel (bool, optional): If True, each newly created node
        gets a secondary __Entity__ label, which is indexed and improves import
        speed and performance. Defaults to False.
        """

        if baseEntityLabel:
            self.query(
                f"CREATE CONSTRAINT ON (b:{BASE_ENTITY_LABEL}) ASSERT b.id IS UNIQUE;"
            )
            self.query(f"CREATE INDEX ON :{BASE_ENTITY_LABEL}(id);")
            self.query(f"CREATE INDEX ON :{BASE_ENTITY_LABEL};")

        for document in graph_documents:
            if include_source:
                if not document.source.metadata.get("id"):
                    # No id supplied: derive a stable one from the text.
                    document.source.metadata["id"] = md5(
                        document.source.page_content.encode("utf-8")
                    ).hexdigest()

                self.query(INCLUDE_DOCS_QUERY, {"document": document.source.__dict__})

            self.query(
                NODE_IMPORT_QUERY,
                {"data": _transform_nodes(document.nodes, baseEntityLabel)},
            )

            rel_data = _transform_relationships(document.relationships, baseEntityLabel)
            # Endpoints are merged first so the relationship import can MATCH them.
            self.query(
                REL_NODES_IMPORT_QUERY,
                {"data": rel_data},
            )
            self.query(
                REL_IMPORT_QUERY,
                {"data": rel_data},
            )

            if include_source:
                self.query(
                    INCLUDE_DOCS_SOURCE_QUERY,
                    {"data": rel_data, "document": document.source.__dict__},
                )
        self.refresh_schema()
|
||||
@@ -0,0 +1,222 @@
|
||||
import logging
|
||||
from string import Template
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)

# Query template, filled per edge type via `$edge_type`: takes one edge of
# that type and returns a "(:src_tag)-[:edge_type]->(:dst_tag)" string built
# from the first tag of each endpoint.
rel_query = Template(
    """
MATCH ()-[e:`$edge_type`]->()
  WITH e limit 1
MATCH (m)-[:`$edge_type`]->(n) WHERE id(m) == src(e) AND id(n) == dst(e)
RETURN "(:" + tags(m)[0] + ")-[:$edge_type]->(:" + tags(n)[0] + ")" AS rels
"""
)

# Maximum number of retries NebulaGraph.execute performs for one query.
RETRY_TIMES = 3
|
||||
|
||||
|
||||
class NebulaGraph:
    """NebulaGraph wrapper for graph operations.

    Exposes ``get_schema``, ``refresh_schema`` and ``query`` on top of a
    NebulaGraph session pool.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        space: str,
        username: str = "root",
        password: str = "nebula",
        address: str = "127.0.0.1",
        port: int = 9669,
        session_pool_size: int = 30,
    ) -> None:
        """Create a new NebulaGraph wrapper instance.

        Opens a session pool and loads the schema immediately.

        Raises:
            ImportError: If ``nebula3`` or ``pandas`` is not installed.
            ValueError: If the connection or the schema refresh fails.
        """
        try:
            import nebula3  # noqa: F401
            import pandas  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "Please install NebulaGraph Python client and pandas first: "
                "`pip install nebula3-python pandas`"
            ) from e

        self.username = username
        self.password = password
        self.address = address
        self.port = port
        self.space = space
        self.session_pool_size = session_pool_size

        self.session_pool = self._get_session_pool()
        self.schema = ""
        # Set schema
        try:
            self.refresh_schema()
        except Exception as e:
            raise ValueError(f"Could not refresh schema. Error: {e}") from e

    def _get_session_pool(self) -> Any:
        """Create and initialise a session pool from the stored settings.

        Raises:
            AssertionError: If any connection setting is empty.
            ValueError: If the host is invalid, authentication fails, or the
                pool cannot be initialised.
        """
        assert all(
            [
                self.username,
                self.password,
                self.address,
                self.port,
                self.space,
            ]
        ), (
            "Please provide all of the following parameters: "
            "username, password, address, port, space"
        )

        from nebula3.Config import SessionPoolConfig
        from nebula3.Exception import AuthFailedException, InValidHostname
        from nebula3.gclient.net.SessionPool import SessionPool

        config = SessionPoolConfig()
        config.max_size = self.session_pool_size

        try:
            session_pool = SessionPool(
                self.username,
                self.password,
                self.space,
                [(self.address, self.port)],
            )
        except InValidHostname as e:
            raise ValueError(
                "Could not connect to NebulaGraph database. "
                "Please ensure that the address and port are correct"
            ) from e

        try:
            session_pool.init(config)
        except AuthFailedException as e:
            raise ValueError(
                "Could not connect to NebulaGraph database. "
                "Please ensure that the username and password are correct"
            ) from e
        except RuntimeError as e:
            raise ValueError(f"Error initializing session pool. Error: {e}") from e

        return session_pool

    def __del__(self) -> None:
        """Best-effort close of the session pool; failures are only logged."""
        try:
            self.session_pool.close()
        except Exception as e:
            logger.warning(f"Could not close session pool. Error: {e}")

    @property
    def get_schema(self) -> str:
        """Returns the schema of the NebulaGraph database"""
        return self.schema

    def execute(self, query: str, params: Optional[dict] = None, retry: int = 0) -> Any:
        """Query NebulaGraph database.

        Args:
            query: Statement to execute.
            params: Query parameters; defaults to none.
            retry: Number of retries already spent on this query.

        Returns:
            The result set from the session pool. An unsuccessful result is
            logged and still returned.

        Raises:
            ValueError: If no valid session is available, or if the query
                still fails after ``RETRY_TIMES`` retries.
        """
        from nebula3.Exception import IOErrorException, NoValidSessionException
        from nebula3.fbthrift.transport.TTransport import TTransportException

        params = params or {}
        try:
            result = self.session_pool.execute_parameter(query, params)
            if not result.is_succeeded():
                logger.warning(
                    f"Error executing query to NebulaGraph. "
                    f"Error: {result.error_msg()}\n"
                    f"Query: {query} \n"
                )
            return result

        except NoValidSessionException as e:
            logger.warning(
                f"No valid session found in session pool. "
                f"Please consider increasing the session pool size. "
                f"Current size: {self.session_pool_size}"
            )
            raise ValueError(
                f"No valid session found in session pool. "
                f"Please consider increasing the session pool size. "
                f"Current size: {self.session_pool_size}"
            ) from e

        except RuntimeError as e:
            if retry < RETRY_TIMES:
                retry += 1
                logger.warning(
                    f"Error executing query to NebulaGraph. "
                    f"Retrying ({retry}/{RETRY_TIMES})...\n"
                    f"query: {query} \n"
                    f"Error: {e}"
                )
                return self.execute(query, params, retry)
            else:
                raise ValueError(
                    f"Error executing query to NebulaGraph. Error: {e}"
                ) from e

        except (TTransportException, IOErrorException) as e:
            # connection issue, try to recreate session pool
            if retry < RETRY_TIMES:
                retry += 1
                logger.warning(
                    f"Connection issue with NebulaGraph. "
                    f"Retrying ({retry}/{RETRY_TIMES})...\n to recreate session pool"
                )
                self.session_pool = self._get_session_pool()
                return self.execute(query, params, retry)
            # Retries exhausted: raise instead of implicitly returning None,
            # which callers would only hit later as an AttributeError.
            raise ValueError(
                f"Error executing query to NebulaGraph. Error: {e}"
            ) from e

    def refresh_schema(self) -> None:
        """
        Refreshes the NebulaGraph schema information.

        Collects the properties of every tag and edge type, plus one sample
        relationship pattern per edge type that has at least one edge.
        """
        tags_schema, edge_types_schema, relationships = [], [], []
        for tag in self.execute("SHOW TAGS").column_values("Name"):
            tag_name = tag.cast()
            tag_schema = {"tag": tag_name, "properties": []}
            r = self.execute(f"DESCRIBE TAG `{tag_name}`")
            props, types = r.column_values("Field"), r.column_values("Type")
            for i in range(r.row_size()):
                tag_schema["properties"].append((props[i].cast(), types[i].cast()))
            tags_schema.append(tag_schema)
        for edge_type in self.execute("SHOW EDGES").column_values("Name"):
            edge_type_name = edge_type.cast()
            edge_schema = {"edge": edge_type_name, "properties": []}
            r = self.execute(f"DESCRIBE EDGE `{edge_type_name}`")
            props, types = r.column_values("Field"), r.column_values("Type")
            for i in range(r.row_size()):
                edge_schema["properties"].append((props[i].cast(), types[i].cast()))
            edge_types_schema.append(edge_schema)

            # build relationships types
            r = self.execute(
                rel_query.substitute(edge_type=edge_type_name)
            ).column_values("rels")
            if len(r) > 0:
                relationships.append(r[0].cast())

        self.schema = (
            f"Node properties: {tags_schema}\n"
            f"Edge properties: {edge_types_schema}\n"
            f"Relationships: {relationships}\n"
        )

    def query(self, query: str, retry: int = 0) -> Dict[str, Any]:
        """Execute ``query`` and return its result as columns.

        Args:
            query: Statement to execute.
            retry: Initial retry count forwarded to ``execute``.

        Returns:
            A dict mapping each column name to the list of its cast values.
        """
        result = self.execute(query, retry=retry)
        columns = result.keys()
        d: Dict[str, list] = {}
        for col_num in range(result.col_size()):
            col_name = columns[col_num]
            col_list = result.column_values(col_name)
            d[col_name] = [x.cast() for x in col_list]
        return d
|
||||
848
venv/Lib/site-packages/langchain_community/graphs/neo4j_graph.py
Normal file
848
venv/Lib/site-packages/langchain_community/graphs/neo4j_graph.py
Normal file
@@ -0,0 +1,848 @@
|
||||
from hashlib import md5
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument
|
||||
from langchain_community.graphs.graph_store import GraphStore
|
||||
|
||||
BASE_ENTITY_LABEL = "__Entity__"
# Labels and relationship types named for exclusion; the schema queries
# below filter on $EXCLUDED_LABELS.
EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"]
EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"]
EXHAUSTIVE_SEARCH_LIMIT = 10000
# Lists with at least this many elements are dropped by value_sanitize.
LIST_LIMIT = 128
# Threshold for returning all available prop values in graph schema
DISTINCT_VALUE_LIMIT = 10

# One `output` map per node label with its non-relationship properties.
node_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
  AND NOT label IN $EXCLUDED_LABELS
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output

"""

# One `output` map per relationship type with its properties.
rel_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
      AND NOT label in $EXCLUDED_LABELS
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {type: nodeLabels, properties: properties} AS output
"""

# One `output` map per (start label, relationship type, end label) combination.
rel_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE type = "RELATIONSHIP" AND elementType = "node"
UNWIND other AS other_node
WITH * WHERE NOT label IN $EXCLUDED_LABELS
    AND NOT other_node IN $EXCLUDED_LABELS
RETURN {start: label, type: property, end: toString(other_node)} AS output
"""

# Query prefix that merges the source Document node; it ends in "WITH d ",
# presumably so a caller can append further clauses (usage not visible here).
include_docs_query = (
    "MERGE (d:Document {id:$document.metadata.id}) "
    "SET d.text = $document.page_content "
    "SET d += $document.metadata "
    "WITH d "
)
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph.clean_string_values",
)
def clean_string_values(text: str) -> str:
    """Flatten a string value onto a single line for use in the schema.

    Every newline and carriage-return character is replaced by one space.

    Args:
        text (str): The input text to clean.

    Returns:
        str: The text with line-break characters replaced by spaces.
    """
    # One translation pass maps both line-break characters to a space.
    line_breaks_to_space = str.maketrans({"\n": " ", "\r": " "})
    return text.translate(line_breaks_to_space)
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph.value_sanitize",
)
def value_sanitize(d: Any) -> Any:
    """Sanitize the input dictionary or list.

    Sanitizes the input by removing embedding-like values, i.e. lists with
    ``LIST_LIMIT`` (128) or more elements, that are mostly irrelevant for
    generating answers in a LLM context. These properties, if left in
    results, can occupy significant context space and detract from
    the LLM's performance by introducing unnecessary noise and cost.

    Args:
        d (Any): The input dictionary or list to sanitize.

    Returns:
        Any: A new sanitized dictionary or list; ``None`` when ``d`` itself is
        an oversized list; ``d`` unchanged when it is neither a dict nor a
        list.
    """
    if isinstance(d, dict):
        new_dict = {}
        for key, value in d.items():
            if isinstance(value, (dict, list)):
                sanitized_value = value_sanitize(value)
                # An oversized list sanitizes to None: drop the key entirely.
                if sanitized_value is not None:
                    new_dict[key] = sanitized_value
            else:
                new_dict[key] = value
        return new_dict

    if isinstance(d, list):
        if len(d) >= LIST_LIMIT:
            return None
        # Sanitize each element exactly once; elements that come back as None
        # (None itself or an oversized nested list) are left out.
        sanitized_items = (value_sanitize(item) for item in d)
        return [item for item in sanitized_items if item is not None]

    return d
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph._get_node_import_query",
)
def _get_node_import_query(baseEntityLabel: bool, include_source: bool) -> str:
    """Assemble the Cypher statement that imports a batch of nodes.

    The statement iterates over the ``$data`` parameter and reads ``id``,
    ``type`` and ``properties`` from each row.

    Args:
        baseEntityLabel (bool): When True, nodes are merged on the base entity
            label and the row type is added as an extra label; otherwise nodes
            are merged directly on the row type.
        include_source (bool): When True, the statement is prefixed with
            ``include_docs_query`` and every node is linked to the document
            node with a ``MENTIONS`` relationship.

    Returns:
        str: The complete Cypher statement.
    """
    fragments = [include_docs_query] if include_source else []
    fragments.append("UNWIND $data AS row ")

    if baseEntityLabel:
        fragments.append(f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.id}}) ")
        fragments.append("SET source += row.properties ")
        if include_source:
            fragments.append("MERGE (d)-[:MENTIONS]->(source) ")
        fragments.append("WITH source, row ")
        fragments.append(
            "CALL apoc.create.addLabels( source, [row.type] ) YIELD node "
        )
    else:
        fragments.append("CALL apoc.merge.node([row.type], {id: row.id}, ")
        fragments.append("row.properties, {}) YIELD node ")
        if include_source:
            fragments.append("MERGE (d)-[:MENTIONS]->(node) ")

    fragments.append("RETURN distinct 'done' AS result")
    return "".join(fragments)
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph._get_rel_import_query",
)
def _get_rel_import_query(baseEntityLabel: bool) -> str:
    """Assemble the Cypher statement that imports a batch of relationships.

    The statement iterates over the ``$data`` parameter and reads ``source``,
    ``target``, ``type`` and ``properties`` from each row (plus
    ``source_label`` / ``target_label`` when the base entity label is not
    used).

    Args:
        baseEntityLabel (bool): When True, both endpoints are merged on the
            base entity label; otherwise they are merged on the labels carried
            by the row.

    Returns:
        str: The complete Cypher statement.
    """
    if baseEntityLabel:
        endpoint_clauses = (
            f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.source}}) "
            f"MERGE (target:`{BASE_ENTITY_LABEL}` {{id: row.target}}) "
            "WITH source, target, row "
        )
    else:
        endpoint_clauses = (
            "CALL apoc.merge.node([row.source_label], {id: row.source},"
            "{}, {}) YIELD node as source "
            "CALL apoc.merge.node([row.target_label], {id: row.target},"
            "{}, {}) YIELD node as target "
        )

    # Only the endpoint lookup differs between the two modes.
    return (
        "UNWIND $data AS row "
        + endpoint_clauses
        + "CALL apoc.merge.relationship(source, row.type, "
        + "{}, row.properties, target) YIELD rel "
        + "RETURN distinct 'done'"
    )
|
||||
|
||||
|
||||
def _format_enhanced_prop_example(prop: Dict[str, Any]) -> Optional[str]:
    """Return the example text for one property of the enhanced schema.

    Returns ``None`` when the property must be left out of the schema (a list
    property without a usable ``min_size`` or whose ``min_size`` exceeds
    ``LIST_LIMIT``), and an empty string when no example is available.
    """
    prop_type = prop["type"]
    # Enhanced info may be missing for a property, so never index it directly.
    values = prop.get("values")

    if prop_type == "STRING":
        if not values:
            return ""
        # A missing distinct count is treated as "above the limit".
        distinct_count = prop.get("distinct_count", DISTINCT_VALUE_LIMIT + 1)
        if distinct_count > DISTINCT_VALUE_LIMIT:
            return f'Example: "{clean_string_values(values[0])}"'
        # Few enough distinct values: list them all.
        return f"Available options: {[clean_string_values(el) for el in values]}"

    if prop_type in ("INTEGER", "FLOAT", "DATE", "DATE_TIME", "LOCAL_DATE_TIME"):
        if prop.get("min") is not None:
            return f"Min: {prop['min']}, Max: {prop['max']}"
        return f'Example: "{values[0]}"' if values else ""

    if prop_type == "LIST":
        # Skip embeddings
        if not prop.get("min_size") or prop["min_size"] > LIST_LIMIT:
            return None
        return f"Min Size: {prop['min_size']}, Max Size: {prop['max_size']}"

    return ""


@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph._format_schema",
)
def _format_schema(schema: Dict, is_enhanced: bool) -> str:
    """Render the structured schema as text.

    Args:
        schema (Dict): Structured schema with the keys ``node_props`` and
            ``rel_props`` (mapping a label/type to a list of property dicts
            that carry ``property`` and ``type``) and ``relationships`` (a list
            of dicts with ``start``, ``type`` and ``end``).
        is_enhanced (bool): When True, each property is listed on its own line
            together with example values or value ranges taken from the
            property dict; otherwise a compact one-line-per-label format is
            used.

    Returns:
        str: The formatted schema text.
    """
    formatted_node_props = []
    formatted_rel_props = []
    if is_enhanced:
        # Enhanced formatting for nodes
        for node_type, properties in schema["node_props"].items():
            formatted_node_props.append(f"- **{node_type}**")
            for prop in properties:
                example = _format_enhanced_prop_example(prop)
                if example is None:
                    continue
                formatted_node_props.append(
                    f" - `{prop['property']}`: {prop['type']} {example}"
                )

        # Enhanced formatting for relationships. The same example logic is used
        # as for nodes, so a relationship property without enhanced info
        # (no "values" key) no longer raises KeyError.
        for rel_type, properties in schema["rel_props"].items():
            formatted_rel_props.append(f"- **{rel_type}**")
            for prop in properties:
                example = _format_enhanced_prop_example(prop)
                if example is None:
                    continue
                formatted_rel_props.append(
                    f" - `{prop['property']}: {prop['type']}` {example}"
                )
    else:
        # Format node properties
        for label, props in schema["node_props"].items():
            props_str = ", ".join(
                f"{prop['property']}: {prop['type']}" for prop in props
            )
            formatted_node_props.append(f"{label} {{{props_str}}}")

        # Format relationship properties using structured_schema
        for rel_type, props in schema["rel_props"].items():
            props_str = ", ".join(
                f"{prop['property']}: {prop['type']}" for prop in props
            )
            formatted_rel_props.append(f"{rel_type} {{{props_str}}}")

    # Format relationships
    formatted_rels = [
        f"(:{el['start']})-[:{el['type']}]->(:{el['end']})"
        for el in schema["relationships"]
    ]

    return "\n".join(
        [
            "Node properties:",
            "\n".join(formatted_node_props),
            "Relationship properties:",
            "\n".join(formatted_rel_props),
            "The relationships:",
            "\n".join(formatted_rels),
        ]
    )
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph._remove_backticks",
)
def _remove_backticks(text: str) -> str:
    """Return ``text`` with every backtick character removed."""
    # Splitting on the backtick and re-joining drops all occurrences.
    pieces = text.split("`")
    return "".join(pieces)
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.Neo4jGraph",
)
class Neo4jGraph(GraphStore):
    """Neo4j database wrapper for various graph operations.

    Parameters:
    url (Optional[str]): The URL of the Neo4j database server.
    username (Optional[str]): The username for database authentication.
    password (Optional[str]): The password for database authentication.
    database (str): The name of the database to connect to. Default is 'neo4j'.
    timeout (Optional[float]): The timeout for transactions in seconds.
            Useful for terminating long-running queries.
            By default, there is no timeout set.
    sanitize (bool): A flag to indicate whether to remove lists with
            128 or more elements from results. Useful for removing
            embedding-like properties from database responses. Default is False.
    refresh_schema (bool): A flag whether to refresh schema information
            at initialization. Default is True.
    enhanced_schema (bool): A flag whether to scan the database for
            example values and use them in the graph schema. Default is False.
    driver_config (Dict): Configuration passed to Neo4j Driver.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        database: Optional[str] = None,
        timeout: Optional[float] = None,
        sanitize: bool = False,
        refresh_schema: bool = True,
        *,
        driver_config: Optional[Dict] = None,
        enhanced_schema: bool = False,
    ) -> None:
        """Create a new Neo4j graph wrapper instance.

        Values not passed explicitly are read from the environment variables
        ``NEO4J_URI``, ``NEO4J_USERNAME``, ``NEO4J_PASSWORD`` and
        ``NEO4J_DATABASE`` (the database falls back to ``"neo4j"``).

        Raises:
            ImportError: If the ``neo4j`` package is not installed.
            ValueError: If the server is unavailable, authentication fails, or
                the APOC procedures needed for the schema cannot be found.
        """
        try:
            import neo4j
        except ImportError as e:
            raise ImportError(
                "Could not import neo4j python package. "
                "Please install it with `pip install neo4j`."
            ) from e

        url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI")
        # if username and password are "", assume Neo4j auth is disabled
        if username == "" and password == "":
            auth = None
        else:
            username = get_from_dict_or_env(
                {"username": username},
                "username",
                "NEO4J_USERNAME",
            )
            password = get_from_dict_or_env(
                {"password": password},
                "password",
                "NEO4J_PASSWORD",
            )
            auth = (username, password)
        database = get_from_dict_or_env(
            {"database": database}, "database", "NEO4J_DATABASE", "neo4j"
        )

        self._driver = neo4j.GraphDatabase.driver(
            url, auth=auth, **(driver_config or {})
        )
        self._database = database
        self.timeout = timeout
        self.sanitize = sanitize
        self._enhanced_schema = enhanced_schema
        self.schema: str = ""
        self.structured_schema: Dict[str, Any] = {}
        # Verify connection
        try:
            self._driver.verify_connectivity()
        except neo4j.exceptions.ServiceUnavailable as e:
            raise ValueError(
                "Could not connect to Neo4j database. "
                "Please ensure that the url is correct"
            ) from e
        except neo4j.exceptions.AuthError as e:
            raise ValueError(
                "Could not connect to Neo4j database. "
                "Please ensure that the username and password are correct"
            ) from e
        # Set schema
        if refresh_schema:
            try:
                self.refresh_schema()
            except neo4j.exceptions.ClientError as e:
                if e.code == "Neo.ClientError.Procedure.ProcedureNotFound":
                    raise ValueError(
                        "Could not use APOC procedures. "
                        "Please ensure the APOC plugin is installed in Neo4j and that "
                        "'apoc.meta.data()' is allowed in Neo4j configuration "
                    ) from e
                # Any other client error is propagated unchanged.
                raise

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Graph"""
        return self.schema

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        """Returns the structured schema of the Graph"""
        return self.structured_schema

    def query(
        self,
        query: str,
        params: dict = {},
    ) -> List[Dict[str, Any]]:
        """Query Neo4j database.

        The query is first run through the driver's ``execute_query``. If that
        fails with an error indicating that the statement has to run in an
        implicit transaction, it is re-run through a session. Any other Neo4j
        error is re-raised.

        Args:
            query (str): The Cypher query to execute.
            params (dict): The parameters to pass to the query. This method
                only forwards the dict to the driver.

        Returns:
            List[Dict[str, Any]]: The list of dictionaries containing the query
            results, passed through ``value_sanitize`` when ``sanitize`` is set.
        """
        from neo4j import Query
        from neo4j.exceptions import Neo4jError

        try:
            data, _, _ = self._driver.execute_query(
                Query(text=query, timeout=self.timeout),
                database_=self._database,
                parameters_=params,
            )
            json_data = [r.data() for r in data]
            if self.sanitize:
                json_data = [value_sanitize(el) for el in json_data]
            return json_data
        except Neo4jError as e:
            is_call_in_transaction_error = (
                e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
                or e.code == "Neo.DatabaseError.Transaction.TransactionStartFailed"
            ) and "in an implicit transaction" in e.message
            is_periodic_commit_error = (
                e.code == "Neo.ClientError.Statement.SemanticError"
                and (
                    "in an open transaction is not possible" in e.message
                    or "tried to execute in an explicit transaction" in e.message
                )
            )
            if not (is_call_in_transaction_error or is_periodic_commit_error):
                raise
        # fallback to allow implicit transactions
        with self._driver.session(database=self._database) as session:
            data = session.run(Query(text=query, timeout=self.timeout), params)
            # Records are read while the session is still open.
            json_data = [r.data() for r in data]
            if self.sanitize:
                json_data = [value_sanitize(el) for el in json_data]
            return json_data

    def refresh_schema(self) -> None:
        """
        Refreshes the Neo4j graph schema information.

        Rebuilds ``structured_schema`` (node properties, relationship
        properties, relationships, constraints and indexes) and the formatted
        ``schema`` string. When the enhanced schema is enabled, each property
        dict is additionally updated with the values returned by the query
        built in ``_enhanced_schema_cypher``.
        """
        from neo4j.exceptions import ClientError, CypherTypeError

        node_properties = [
            el["output"]
            for el in self.query(
                node_properties_query,
                params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
            )
        ]
        rel_properties = [
            el["output"]
            for el in self.query(
                rel_properties_query, params={"EXCLUDED_LABELS": EXCLUDED_RELS}
            )
        ]
        relationships = [
            el["output"]
            for el in self.query(
                rel_query,
                params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
            )
        ]

        # Get constraints & indexes
        try:
            constraint = self.query("SHOW CONSTRAINTS")
            index = self.query(
                "CALL apoc.schema.nodes() YIELD label, properties, type, size, "
                "valuesSelectivity WHERE type = 'RANGE' RETURN *, "
                "size * valuesSelectivity as distinctValues"
            )
        except (
            ClientError
        ):  # Read-only user might not have access to schema information
            constraint = []
            index = []

        self.structured_schema = {
            "node_props": {el["labels"]: el["properties"] for el in node_properties},
            "rel_props": {el["type"]: el["properties"] for el in rel_properties},
            "relationships": relationships,
            "metadata": {"constraint": constraint, "index": index},
        }
        if self._enhanced_schema:
            schema_counts = self.query(
                "CALL apoc.meta.graphSample() YIELD nodes, relationships "
                "RETURN nodes, [rel in relationships | {name:apoc.any.property"
                "(rel, 'type'), count: apoc.any.property(rel, 'count')}]"
                " AS relationships"
            )
            # Update node info
            for node in schema_counts[0]["nodes"]:
                # Skip bloom labels
                if node["name"] in EXCLUDED_LABELS:
                    continue
                node_props = self.structured_schema["node_props"].get(node["name"])
                if not node_props:  # The node has no properties
                    continue
                enhanced_cypher = self._enhanced_schema_cypher(
                    node["name"], node_props, node["count"] < EXHAUSTIVE_SEARCH_LIMIT
                )
                # Due to schema-flexible nature of neo4j errors can happen
                try:
                    enhanced_info = self.query(enhanced_cypher)[0]["output"]
                    for prop in node_props:
                        if prop["property"] in enhanced_info:
                            prop.update(enhanced_info[prop["property"]])
                except CypherTypeError:
                    continue
            # Update rel info
            for rel in schema_counts[0]["relationships"]:
                # Skip bloom labels
                if rel["name"] in EXCLUDED_RELS:
                    continue
                rel_props = self.structured_schema["rel_props"].get(rel["name"])
                if not rel_props:  # The rel has no properties
                    continue
                enhanced_cypher = self._enhanced_schema_cypher(
                    rel["name"],
                    rel_props,
                    rel["count"] < EXHAUSTIVE_SEARCH_LIMIT,
                    is_relationship=True,
                )
                try:
                    enhanced_info = self.query(enhanced_cypher)[0]["output"]
                    for prop in rel_props:
                        if prop["property"] in enhanced_info:
                            prop.update(enhanced_info[prop["property"]])
                # Due to schema-flexible nature of neo4j errors can happen
                except CypherTypeError:
                    continue

        self.schema = _format_schema(self.structured_schema, self._enhanced_schema)

    def add_graph_documents(
        self,
        graph_documents: List[GraphDocument],
        include_source: bool = False,
        baseEntityLabel: bool = False,
    ) -> None:
        """
        This method constructs nodes and relationships in the graph based on the
        provided GraphDocument objects.

        Note that the documents are modified in place: backticks are removed
        from node types, and a missing source ``id`` is filled in.

        Parameters:
        - graph_documents (List[GraphDocument]): A list of GraphDocument objects
        that contain the nodes and relationships to be added to the graph. Each
        GraphDocument should encapsulate the structure of part of the graph,
        including nodes, relationships, and the source document information.
        - include_source (bool, optional): If True, stores the source document
        and links it to nodes in the graph using the MENTIONS relationship.
        This is useful for tracing back the origin of data. Merges source
        documents based on the `id` property from the source document metadata
        if available; otherwise it calculates the MD5 hash of `page_content`
        for merging process. Defaults to False.
        - baseEntityLabel (bool, optional): If True, each newly created node
        gets a secondary __Entity__ label, which is indexed and improves import
        speed and performance. Defaults to False.
        """
        if baseEntityLabel:  # Check if constraint already exists
            known_constraints = self.structured_schema.get("metadata", {}).get(
                "constraint", []
            )
            constraint_exists = any(
                el["labelsOrTypes"] == [BASE_ENTITY_LABEL]
                and el["properties"] == ["id"]
                for el in known_constraints
            )

            if not constraint_exists:
                # Create constraint
                self.query(
                    f"CREATE CONSTRAINT IF NOT EXISTS FOR (b:{BASE_ENTITY_LABEL}) "
                    "REQUIRE b.id IS UNIQUE;"
                )
                self.refresh_schema()  # Refresh constraint information

        node_import_query = _get_node_import_query(baseEntityLabel, include_source)
        rel_import_query = _get_rel_import_query(baseEntityLabel)
        for document in graph_documents:
            if not document.source.metadata.get("id"):
                # No id supplied: derive one from the page content.
                document.source.metadata["id"] = md5(
                    document.source.page_content.encode("utf-8")
                ).hexdigest()

            # Remove backticks from node types
            for node in document.nodes:
                node.type = _remove_backticks(node.type)
            # Import nodes
            self.query(
                node_import_query,
                {
                    "data": [el.__dict__ for el in document.nodes],
                    "document": document.source.__dict__,
                },
            )
            # Import relationships
            self.query(
                rel_import_query,
                {
                    "data": [
                        {
                            "source": el.source.id,
                            "source_label": _remove_backticks(el.source.type),
                            "target": el.target.id,
                            "target_label": _remove_backticks(el.target.type),
                            "type": _remove_backticks(
                                el.type.replace(" ", "_").upper()
                            ),
                            "properties": el.properties,
                        }
                        for el in document.relationships
                    ]
                },
            )

    def _enhanced_schema_cypher(
        self,
        label_or_type: str,
        properties: List[Dict[str, Any]],
        exhaustive: bool,
        is_relationship: bool = False,
    ) -> str:
        """Build the Cypher query that collects enhanced schema information.

        The query returns a single ``output`` map keyed by property name. Each
        entry holds, depending on the property type, example values and a
        distinct count, a min/max range, or min/max list sizes. Properties of
        any other type are left out of the map.

        Args:
            label_or_type (str): Node label or relationship type to inspect.
            properties (List[Dict[str, Any]]): Property dicts with the keys
                ``property`` and ``type``.
            exhaustive (bool): When True, all matching elements are scanned;
                otherwise only 5 elements are sampled.
            is_relationship (bool): Whether ``label_or_type`` names a
                relationship type rather than a node label.

        Returns:
            str: The Cypher query text.
        """
        if is_relationship:
            match_clause = f"MATCH ()-[n:`{label_or_type}`]->()"
        else:
            match_clause = f"MATCH (n:`{label_or_type}`)"

        range_types = ("INTEGER", "FLOAT", "DATE", "DATE_TIME", "LOCAL_DATE_TIME")
        with_clauses: List[str] = []
        output_dict: Dict[str, str] = {}
        if exhaustive:
            for prop in properties:
                prop_name = prop["property"]
                prop_type = prop["type"]
                if prop_type == "STRING":
                    with_clauses.append(
                        f"collect(distinct substring(toString(n.`{prop_name}`)"
                        f", 0, 50)) AS `{prop_name}_values`"
                    )
                    return_clause = (
                        f"values:`{prop_name}_values`[..{DISTINCT_VALUE_LIMIT}],"
                        f" distinct_count: size(`{prop_name}_values`)"
                    )
                elif prop_type in range_types:
                    with_clauses.append(f"min(n.`{prop_name}`) AS `{prop_name}_min`")
                    with_clauses.append(f"max(n.`{prop_name}`) AS `{prop_name}_max`")
                    with_clauses.append(
                        f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
                    )
                    return_clause = (
                        f"min: toString(`{prop_name}_min`), "
                        f"max: toString(`{prop_name}_max`), "
                        f"distinct_count: `{prop_name}_distinct`"
                    )
                elif prop_type == "LIST":
                    with_clauses.append(
                        f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
                        f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
                    )
                    return_clause = (
                        f"min_size: `{prop_name}_size_min`, "
                        f"max_size: `{prop_name}_size_max`"
                    )
                else:
                    # BOOLEAN, POINT, DURATION and any other type: nothing to
                    # collect. Skipping here avoids failing on types that have
                    # no dedicated branch.
                    continue
                output_dict[prop_name] = "{" + return_clause + "}"
        else:
            # Just sample 5 random nodes
            match_clause += " WITH n LIMIT 5"
            for prop in properties:
                prop_name = prop["property"]
                prop_type = prop["type"]

                # Check if indexed property, we can still do exhaustive
                prop_index = [
                    el
                    for el in self.structured_schema["metadata"]["index"]
                    if el["label"] == label_or_type
                    and el["properties"] == [prop_name]
                    and el["type"] == "RANGE"
                ]
                if prop_type == "STRING":
                    if (
                        prop_index
                        and prop_index[0].get("size") > 0
                        and prop_index[0].get("distinctValues") <= DISTINCT_VALUE_LIMIT
                    ):
                        distinct_values = self.query(
                            f"CALL apoc.schema.properties.distinct("
                            f"'{label_or_type}', '{prop_name}') YIELD value"
                        )[0]["value"]
                        # NOTE(review): the values are embedded through the
                        # Python list repr; confirm this stays valid Cypher for
                        # values that contain quote characters.
                        return_clause = (
                            f"values: {distinct_values},"
                            f" distinct_count: {len(distinct_values)}"
                        )
                    else:
                        with_clauses.append(
                            f"collect(distinct substring(toString(n.`{prop_name}`)"
                            f", 0, 50)) AS `{prop_name}_values`"
                        )
                        return_clause = f"values: `{prop_name}_values`"
                elif prop_type in range_types:
                    if not prop_index:
                        with_clauses.append(
                            f"collect(distinct toString(n.`{prop_name}`)) "
                            f"AS `{prop_name}_values`"
                        )
                        return_clause = f"values: `{prop_name}_values`"
                    else:
                        with_clauses.append(
                            f"min(n.`{prop_name}`) AS `{prop_name}_min`"
                        )
                        with_clauses.append(
                            f"max(n.`{prop_name}`) AS `{prop_name}_max`"
                        )
                        with_clauses.append(
                            f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
                        )
                        return_clause = (
                            f"min: toString(`{prop_name}_min`), "
                            f"max: toString(`{prop_name}_max`), "
                            f"distinct_count: `{prop_name}_distinct`"
                        )
                elif prop_type == "LIST":
                    with_clauses.append(
                        f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
                        f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
                    )
                    return_clause = (
                        f"min_size: `{prop_name}_size_min`, "
                        f"max_size: `{prop_name}_size_max`"
                    )
                else:
                    # BOOLEAN, POINT, DURATION and any other type: skipped.
                    continue

                output_dict[prop_name] = "{" + return_clause + "}"

        return_clause = (
            "RETURN {"
            + ", ".join(f"`{k}`: {v}" for k, v in output_dict.items())
            + "} AS output"
        )
        if not with_clauses:
            # Nothing is aggregated, so the RETURN map only holds literals (or
            # is empty). Emitting it on its own avoids an empty "WITH " clause,
            # which is not valid Cypher, and still yields exactly one row.
            return return_clause

        with_clause = "WITH " + ",\n ".join(with_clauses)

        # Combine all parts of the Cypher query
        return "\n".join([match_clause, with_clause, return_clause])
|
||||
@@ -0,0 +1,426 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
|
||||
|
||||
class NeptuneQueryException(Exception):
    """Exception for the Neptune queries.

    Carries a ``message`` and ``details``. Both are taken from the keys of the
    same name when a dict is given, falling back to ``"unknown"`` for a missing
    key. Any other argument is stored as the message and the details are set
    to ``"unknown"``.
    """

    def __init__(self, exception: Union[str, Dict]):
        if not isinstance(exception, dict):
            self.message = exception
            self.details = "unknown"
            return
        self.message = exception.get("message", "unknown")
        self.details = exception.get("details", "unknown")

    def get_message(self) -> str:
        """Return the stored message."""
        return self.message

    def get_details(self) -> Any:
        """Return the stored details."""
        return self.details
|
||||
|
||||
|
||||
class BaseNeptuneGraph(ABC):
    """Abstract base class for Neptune.

    Subclasses provide ``query`` and ``_get_summary``; this class builds the
    textual ``schema`` from them in ``_refresh_schema``.
    """

    @property
    def get_schema(self) -> str:
        """Return the schema of the Neptune database"""
        return self.schema

    @abstractmethod
    def query(self, query: str, params: dict = {}) -> dict:
        """Run ``query`` and return its result rows.

        The helpers in this class iterate over the return value and read keys
        from each row.
        """
        raise NotImplementedError()

    @abstractmethod
    def _get_summary(self) -> Dict:
        """Return a summary dict with the keys ``nodeLabels`` and ``edgeLabels``."""
        raise NotImplementedError()

    def _get_labels(self) -> Tuple[List[str], List[str]]:
        """Get node and edge labels from the Neptune statistics summary"""
        summary = self._get_summary()
        n_labels = summary["nodeLabels"]
        e_labels = summary["edgeLabels"]
        return n_labels, e_labels

    def _get_triples(self, e_labels: List[str]) -> List[str]:
        """Return ``(:`a`)-[:`e`]->(:`b`)`` patterns for the given edge labels.

        Only the first label of the start and end node is used.
        """
        triple_query = (
            "\n"
            "MATCH (a)-[e:`{e_label}`]->(b)\n"
            "WITH a,e,b LIMIT 3000\n"
            "RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to\n"
            "LIMIT 10\n"
        )

        triple_template = "(:`{a}`)-[:`{e}`]->(:`{b}`)"
        triple_schema = []
        for label in e_labels:
            rows = self.query(triple_query.format(e_label=label))
            triple_schema.extend(
                triple_template.format(a=row["from"][0], e=row["edge"], b=row["to"][0])
                for row in rows
            )

        return triple_schema

    @staticmethod
    def _collect_property_types(rows: Any, types: Dict) -> List[Dict[str, Any]]:
        """Return the distinct ``{property, type}`` pairs found in ``rows``.

        Each row is expected to carry a ``props`` mapping. The Python type name
        of every value is translated through ``types``; a type name missing
        from ``types`` raises ``KeyError``. Pairs are returned in first-seen
        order so the resulting schema is stable between runs.
        """
        seen: Dict[Tuple[str, Any], None] = {}
        for row in rows:
            for key, value in row["props"].items():
                seen[(key, types[type(value).__name__])] = None
        return [{"property": key, "type": type_name} for key, type_name in seen]

    def _get_node_properties(self, n_labels: List[str], types: Dict) -> List:
        """Return one ``{properties, labels}`` dict per node label.

        Property types are derived from up to 100 nodes per label.
        """
        node_properties_query = (
            "\n"
            "MATCH (a:`{n_label}`)\n"
            "RETURN properties(a) AS props\n"
            "LIMIT 100\n"
        )
        node_properties = []
        for label in n_labels:
            rows = self.query(node_properties_query.format(n_label=label))
            node_properties.append(
                {
                    "properties": self._collect_property_types(rows, types),
                    "labels": label,
                }
            )

        return node_properties

    def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List:
        """Return one ``{type, properties}`` dict per edge label.

        Property types are derived from up to 100 edges per label.
        """
        edge_properties_query = (
            "\n"
            "MATCH ()-[e:`{e_label}`]->()\n"
            "RETURN properties(e) AS props\n"
            "LIMIT 100\n"
        )
        edge_properties = []
        for label in e_labels:
            rows = self.query(edge_properties_query.format(e_label=label))
            edge_properties.append(
                {
                    "type": label,
                    "properties": self._collect_property_types(rows, types),
                }
            )

        return edge_properties

    def _refresh_schema(self) -> None:
        """
        Refreshes the Neptune graph schema information.

        Stores the result in ``self.schema``.
        """

        # Maps Python type names of returned values to schema type names.
        types = {
            "str": "STRING",
            "float": "DOUBLE",
            "int": "INTEGER",
            "list": "LIST",
            "dict": "MAP",
            "bool": "BOOLEAN",
        }
        n_labels, e_labels = self._get_labels()
        triple_schema = self._get_triples(e_labels)
        node_properties = self._get_node_properties(n_labels, types)
        edge_properties = self._get_edge_properties(e_labels, types)

        self.schema = (
            "\n"
            "Node properties are the following:\n"
            f"{node_properties}\n"
            "Relationship properties are the following:\n"
            f"{edge_properties}\n"
            "The relationships are the following:\n"
            f"{triple_schema}\n"
        )
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.15",
    removal="1.0",
    alternative_import="langchain_aws.NeptuneAnalyticsGraph",
)
class NeptuneAnalyticsGraph(BaseNeptuneGraph):
    """Neptune Analytics wrapper for graph operations.

    Parameters:
        client: optional boto3 Neptune client
        credentials_profile_name: optional AWS profile name
        region_name: optional AWS region, e.g., us-west-2
        graph_identifier: the graph identifier for a Neptune Analytics graph

    Example:
        .. code-block:: python

            graph = NeptuneAnalyticsGraph(
                graph_identifier='<my-graph-id>'
            )

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        graph_identifier: str,
        client: Any = None,
        credentials_profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
    ) -> None:
        """Create a new Neptune Analytics graph wrapper instance.

        Args:
            graph_identifier: identifier of the graph passed to every API call.
            client: pre-built client; when given, no boto3 session is created.
            credentials_profile_name: AWS profile used to build the session.
            region_name: AWS region forwarded to ``session.client``.

        Raises:
            ImportError: if boto3 is missing or does not know ``neptune-graph``.
            ValueError: if the client could not be created for another reason.
            NeptuneQueryException: if the initial schema refresh fails.
        """
        # Stored unconditionally: query() and _get_summary() read it even when
        # the caller supplies a ready-made client.
        self.graph_identifier = graph_identifier

        try:
            if client is not None:
                self.client = client
            else:
                import boto3

                if credentials_profile_name is not None:
                    session = boto3.Session(profile_name=credentials_profile_name)
                else:
                    # use default credentials
                    session = boto3.Session()

                if region_name:
                    self.client = session.client(
                        "neptune-graph", region_name=region_name
                    )
                else:
                    self.client = session.client("neptune-graph")

        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e
        except Exception as e:
            # Matched by name so botocore does not have to be importable here.
            if type(e).__name__ == "UnknownServiceError":
                raise ImportError(
                    "NeptuneAnalyticsGraph requires a boto3 version 1.34.40 "
                    "or greater. "
                    "Please install it with `pip install -U boto3`."
                ) from e
            else:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e

        try:
            self._refresh_schema()
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "Could not get schema for Neptune database",
                    # Same key as the other NeptuneQueryException payloads
                    # raised by this class.
                    "details": str(e),
                }
            ) from e

    def query(self, query: str, params: Optional[dict] = None) -> Dict[str, Any]:
        """Run an openCypher query against the Neptune Analytics graph.

        Args:
            query: openCypher query text.
            params: optional query parameters; an empty mapping is sent when
                omitted.

        Returns:
            The ``results`` entry of the JSON-decoded response payload.

        Raises:
            NeptuneQueryException: if the call or the payload decoding fails.
        """
        try:
            resp = self.client.execute_query(
                graphIdentifier=self.graph_identifier,
                queryString=query,
                parameters={} if params is None else params,
                language="OPEN_CYPHER",
            )
            return json.loads(resp["payload"].read().decode("UTF-8"))["results"]
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "An error occurred while executing the query.",
                    "details": str(e),
                }
            ) from e

    def _get_summary(self) -> Dict:
        """Return the ``graphSummary`` section of the detailed graph summary.

        Raises:
            NeptuneQueryException: if the summary call fails or the response
                has no ``graphSummary`` entry.
        """
        try:
            response = self.client.get_graph_summary(
                graphIdentifier=self.graph_identifier, mode="detailed"
            )
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": ("Summary API error occurred on Neptune Analytics"),
                    "details": str(e),
                }
            ) from e

        try:
            return response["graphSummary"]
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "Summary API did not return a valid response.",
                    # The response is subscripted like a mapping; str() works
                    # for any object, unlike a requests-style ``.content``.
                    "details": str(response),
                }
            ) from e
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.15",
    removal="1.0",
    alternative_import="langchain_aws.NeptuneGraph",
)
class NeptuneGraph(BaseNeptuneGraph):
    """Neptune wrapper for graph operations.

    Parameters:
        host: endpoint for the database instance
        port: port number for the database instance, default is 8182
        use_https: whether to use secure connection, default is True
        client: optional boto3 Neptune client
        credentials_profile_name: optional AWS profile name
        region_name: optional AWS region, e.g., us-west-2
        sign: optional, whether to sign the request payload, default is True

    Example:
        .. code-block:: python

            graph = NeptuneGraph(
                host='<my-cluster>',
                port=8182
            )

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        host: str,
        port: int = 8182,
        use_https: bool = True,
        client: Any = None,
        credentials_profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
        sign: bool = True,
    ) -> None:
        """Create a new Neptune graph wrapper instance.

        Args:
            host: endpoint for the database instance.
            port: port number for the database instance.
            use_https: selects ``https`` or ``http`` for the endpoint URL.
            client: pre-built client; when given, the remaining connection
                arguments are not used.
            credentials_profile_name: AWS profile used to build the session.
            region_name: AWS region forwarded to ``session.client``.
            sign: when False the client is created with unsigned requests.

        Raises:
            ImportError: if boto3 is missing or does not know ``neptunedata``.
            ValueError: if the client could not be created for another reason.
            NeptuneQueryException: if the initial schema refresh fails.
        """
        try:
            if client is not None:
                self.client = client
            else:
                import boto3

                if credentials_profile_name is not None:
                    session = boto3.Session(profile_name=credentials_profile_name)
                else:
                    # use default credentials
                    session = boto3.Session()

                client_params = {}
                if region_name:
                    client_params["region_name"] = region_name

                protocol = "https" if use_https else "http"
                client_params["endpoint_url"] = f"{protocol}://{host}:{port}"

                if sign:
                    self.client = session.client("neptunedata", **client_params)
                else:
                    from botocore import UNSIGNED
                    from botocore.config import Config

                    self.client = session.client(
                        "neptunedata",
                        **client_params,
                        config=Config(signature_version=UNSIGNED),
                    )

        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e
        except Exception as e:
            # Matched by name so botocore does not have to be importable here.
            if type(e).__name__ == "UnknownServiceError":
                raise ImportError(
                    "NeptuneGraph requires a boto3 version 1.28.38 or greater. "
                    "Please install it with `pip install -U boto3`."
                ) from e
            else:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e

        try:
            self._refresh_schema()
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "Could not get schema for Neptune database",
                    # Same key as the other NeptuneQueryException payloads
                    # raised by this class.
                    "details": str(e),
                }
            ) from e

    def query(self, query: str, params: Optional[dict] = None) -> Dict[str, Any]:
        """Run an openCypher query against the Neptune database.

        Args:
            query: openCypher query text.
            params: optional query parameters. When non-empty they are sent
                JSON-encoded; previously they were silently dropped.

        Returns:
            The ``results`` entry of the service response.

        Raises:
            NeptuneQueryException: if the call fails.
        """
        try:
            request: Dict[str, Any] = {"openCypherQuery": query}
            if params:
                # NOTE(review): the neptunedata API documents ``parameters``
                # as a JSON string - confirm against the boto3 reference.
                request["parameters"] = json.dumps(params)
            return self.client.execute_open_cypher_query(**request)["results"]
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "An error occurred while executing the query.",
                    "details": str(e),
                }
            ) from e

    def _get_summary(self) -> Dict:
        """Return the ``graphSummary`` section of the property-graph summary.

        Raises:
            NeptuneQueryException: if the summary call fails or the response
                lacks ``payload.graphSummary``.
        """
        try:
            response = self.client.get_propertygraph_summary()
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": (
                        "Summary API is not available for this instance of Neptune,"
                        "ensure the engine version is >=1.2.1.0"
                    ),
                    "details": str(e),
                }
            ) from e

        try:
            return response["payload"]["graphSummary"]
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "Summary API did not return a valid response.",
                    # The response is subscripted like a mapping; str() works
                    # for any object, unlike a requests-style ``.content``.
                    "details": str(response),
                }
            ) from e
|
||||
@@ -0,0 +1,302 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
|
||||
import requests
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
|
||||
# SPARQL query selecting every distinct resource typed as owl:DatatypeProperty.
DTPROP_QUERY = """
SELECT DISTINCT ?elem
WHERE {
?elem a owl:DatatypeProperty .
}
"""

# SPARQL query selecting every distinct resource typed as owl:ObjectProperty.
OPROP_QUERY = """
SELECT DISTINCT ?elem
WHERE {
?elem a owl:ObjectProperty .
}
"""

# Maps each schema element kind to the SPARQL query that retrieves it.
# Kinds mapped to None are skipped by the query loop in
# NeptuneRdfGraph._refresh_schema; "classes" and "rels" are filled from the
# graph summary there instead.
ELEM_TYPES = {
    "classes": None,
    "rels": None,
    "dtprops": DTPROP_QUERY,
    "oprops": OPROP_QUERY,
}
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.15",
    removal="1.0",
    alternative_import="langchain_aws.NeptuneRdfGraph",
)
class NeptuneRdfGraph:
    """Neptune wrapper for RDF graph operations.

    Args:
        host: endpoint for the database instance
        port: port number for the database instance, default is 8182
        use_iam_auth: boolean indicating IAM auth is enabled in Neptune cluster
        use_https: whether to use secure connection, default is True
        client: optional boto3 Neptune client
        credentials_profile_name: optional AWS profile name
        region_name: optional AWS region, e.g., us-west-2
        service: optional service name, default is neptunedata
        sign: optional, whether to sign the request payload, default is True

    Example:
        .. code-block:: python

            graph = NeptuneRdfGraph(
                host='<SPARQL host>',
                port=<SPARQL port>
            )
            schema = graph.get_schema

            OR
            graph = NeptuneRdfGraph(
                host='<SPARQL host>',
                port=<SPARQL port>
            )
            schema_elem = graph.get_schema_elements
            #... change schema_elements ...
            graph.load_schema(schema_elem)

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        host: str,
        port: int = 8182,
        use_https: bool = True,
        use_iam_auth: bool = False,
        client: Any = None,
        credentials_profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
        service: str = "neptunedata",
        sign: bool = True,
    ) -> None:
        """Connect to Neptune and introspect the RDF schema.

        Raises:
            ImportError: if boto3 is missing or does not know ``service``.
            ValueError: if the session or client could not be created.
        """
        self.use_iam_auth = use_iam_auth
        self.region_name = region_name

        # The SPARQL endpoint follows ``use_https`` just like the boto3
        # endpoint URL built below.
        protocol = "https" if use_https else "http"
        self.query_endpoint = f"{protocol}://{host}:{port}/sparql"

        # Stays None only when a client is supplied and IAM auth is off,
        # i.e. when nothing in this class needs a session.
        self.session = None

        try:
            if client is None or use_iam_auth:
                # A session is needed to build the client and/or to fetch the
                # credentials used for signing SPARQL requests in query().
                import boto3

                if credentials_profile_name is not None:
                    self.session = boto3.Session(profile_name=credentials_profile_name)
                else:
                    # use default credentials
                    self.session = boto3.Session()

            if client is not None:
                self.client = client
            else:
                client_params = {}
                if region_name:
                    client_params["region_name"] = region_name

                client_params["endpoint_url"] = f"{protocol}://{host}:{port}"

                if sign:
                    self.client = self.session.client(service, **client_params)
                else:
                    from botocore import UNSIGNED
                    from botocore.config import Config

                    self.client = self.session.client(
                        service,
                        **client_params,
                        config=Config(signature_version=UNSIGNED),
                    )

        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e
        except Exception as e:
            # Matched by name so botocore does not have to be importable here.
            if type(e).__name__ == "UnknownServiceError":
                raise ImportError(
                    "NeptuneRdfGraph requires a boto3 version 1.28.38 or greater. "
                    "Please install it with `pip install -U boto3`."
                ) from e
            else:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e

        # Set schema
        self.schema = ""
        self.schema_elements: Dict[str, Any] = {}
        self._refresh_schema()

    @property
    def get_schema(self) -> str:
        """
        Returns the schema of the graph database.
        """
        return self.schema

    @property
    def get_schema_elements(self) -> Dict[str, Any]:
        """
        Returns the introspected schema elements the schema text is built from.
        """
        return self.schema_elements

    def get_summary(self) -> Dict[str, Any]:
        """
        Obtain Neptune statistical summary of classes and predicates in the graph.
        """
        return self.client.get_rdf_graph_summary(mode="detailed")

    def query(
        self,
        query: str,
    ) -> Dict[str, Any]:
        """
        Run Neptune query.

        The query is POSTed form-encoded to the SPARQL endpoint; when
        ``use_iam_auth`` is set the request is SigV4-signed first. Returns the
        JSON-decoded response body.
        """
        data = {"query": query}
        content_type = "application/x-www-form-urlencoded"

        if self.use_iam_auth:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest

            frozen = self.session.get_credentials().get_frozen_credentials()
            creds = SimpleNamespace(
                access_key=frozen.access_key,
                secret_key=frozen.secret_key,
                token=frozen.token,
                region=self.region_name,
            )
            request = AWSRequest(
                method="POST", url=self.query_endpoint, data=data, params=None
            )
            SigV4Auth(creds, "neptune-db", self.region_name).add_auth(request)
            # Set after signing, exactly as before.
            request.headers["Content-Type"] = content_type
            request_hdr = request.headers
        else:
            request_hdr = {"Content-Type": content_type}

        queryres = requests.request(
            method="POST", url=self.query_endpoint, headers=request_hdr, data=data
        )
        return json.loads(queryres.text)

    def load_schema(self, schema_elements: Dict[str, Any]) -> None:
        """
        Generates and sets schema from schema_elements. Helpful in
        cases where introspected schema needs pruning.

        ``schema_elements`` must hold, for every key of ``ELEM_TYPES``, a list
        of records with ``uri`` and ``local`` entries.
        """
        elem_str = {
            elem: ", ".join(
                f"<{rec['uri']}> ({rec['local']})" for rec in schema_elements[elem]
            )
            for elem in ELEM_TYPES
        }

        # Each heading is paired with its own element kind: object properties
        # come from "oprops", data properties from "dtprops".
        self.schema = (
            "In the following, each IRI is followed by the local name and "
            "optionally its description in parentheses. \n"
            "The graph supports the following node types:\n"
            f"{elem_str['classes']}\n"
            "The graph supports the following relationships:\n"
            f"{elem_str['rels']}\n"
            "The graph supports the following OWL object properties:\n"
            f"{elem_str['oprops']}\n"
            "The graph supports the following OWL data properties:\n"
            f"{elem_str['dtprops']}"
        )

    def _get_local_name(self, iri: str) -> Sequence[str]:
        """
        Split IRI into prefix and local

        Returns a two-item list ``[prefix, local]``. Raises ValueError when the
        IRI contains neither '#' nor '/'.
        """
        if "#" in iri:
            tokens = iri.split("#")
            return [f"{tokens[0]}#", tokens[-1]]
        elif "/" in iri:
            prefix, local = iri.rsplit("/", 1)
            return [f"{prefix}/", local]
        else:
            raise ValueError(f"Unexpected IRI '{iri}', contains neither '#' nor '/'.")

    def _build_elem_records(self, uris: Sequence[str]) -> list:
        """Turn URIs into ``{"uri", "local"}`` records and note their prefixes."""
        seen_prefixes = self.schema_elements["distinct_prefixes"]
        records = []
        for uri in uris:
            prefix, local = self._get_local_name(uri)
            records.append({"uri": uri, "local": local})
            if prefix not in seen_prefixes:
                seen_prefixes[prefix] = "y"
        return records

    def _refresh_schema(self) -> None:
        """
        Query Neptune to introspect schema.
        """
        self.schema_elements["distinct_prefixes"] = {}

        # get summary and build list of classes and rels
        graph_summary = self.get_summary()["payload"]["graphSummary"]
        self.schema_elements["classes"] = self._build_elem_records(
            list(graph_summary["classes"])
        )
        # Iterating each predicate entry yields the URIs it contains.
        self.schema_elements["rels"] = self._build_elem_records(
            [p for entry in graph_summary["predicates"] for p in entry]
        )

        # get dtprops and oprops too
        for elem, sparql in ELEM_TYPES.items():
            if not sparql:
                continue
            bindings = self.query(sparql)["results"]["bindings"]
            self.schema_elements[elem] = self._build_elem_records(
                [b["elem"]["value"] for b in bindings]
            )

        self.load_schema(self.schema_elements)
|
||||
@@ -0,0 +1,218 @@
|
||||
"""Networkx wrapper for graph operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple
|
||||
|
||||
# Token separating individual serialized triples inside a knowledge string;
# parse_triples() splits its input on it.
KG_TRIPLE_DELIMITER = "<|>"
|
||||
|
||||
|
||||
class KnowledgeTriple(NamedTuple):
    """A (subject, predicate, object) statement held in the graph."""

    subject: str
    predicate: str
    object_: str

    @classmethod
    def from_string(cls, triple_string: str) -> "KnowledgeTriple":
        """Build a triple from text shaped like ``(subject, predicate, object)``.

        The text is stripped and split on ``", "``; the first character of the
        subject and the last character of the object (the enclosing
        parentheses in the expected form) are dropped.

        Raises:
            ValueError: if the split does not produce exactly three fields.
        """
        fields = triple_string.strip().split(", ")
        # Tuple unpacking enforces the three-field shape.
        head, relation, tail = fields
        return cls(head[1:], relation, tail[:-1])
|
||||
|
||||
|
||||
def parse_triples(knowledge_str: str) -> List[KnowledgeTriple]:
    """Turn a delimiter-separated knowledge string into triples.

    Blank input and the literal ``NONE`` give an empty list. Chunks that
    ``KnowledgeTriple.from_string`` rejects with ``ValueError`` are skipped.
    """
    text = knowledge_str.strip()
    if text in ("", "NONE"):
        return []

    parsed = []
    for chunk in text.split(KG_TRIPLE_DELIMITER):
        try:
            parsed.append(KnowledgeTriple.from_string(chunk))
        except ValueError:
            # Malformed chunk: ignore it and keep going.
            continue
    return parsed
|
||||
|
||||
|
||||
def get_entities(entity_str: str) -> List[str]:
    """Split a comma-separated entity string into trimmed names.

    The literal ``NONE`` (ignoring surrounding whitespace) gives an empty list.
    """
    if entity_str.strip() == "NONE":
        return []
    return list(map(str.strip, entity_str.split(",")))
|
||||
|
||||
|
||||
class NetworkxEntityGraph:
    """Networkx wrapper for entity graph operations.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, graph: Optional[Any] = None) -> None:
        """Create a new graph.

        Args:
            graph: existing ``networkx.DiGraph`` to wrap; a fresh empty one is
                created when omitted.

        Raises:
            ImportError: if networkx is not installed.
            ValueError: if ``graph`` is not a ``networkx.DiGraph``.
        """
        try:
            import networkx as nx
        except ImportError as e:
            raise ImportError(
                "Could not import networkx python package. "
                "Please install it with `pip install networkx`."
            ) from e
        if graph is not None:
            if not isinstance(graph, nx.DiGraph):
                raise ValueError("Passed in graph is not of correct shape")
            self._graph = graph
        else:
            self._graph = nx.DiGraph()

    @classmethod
    def from_gml(cls, gml_path: str) -> "NetworkxEntityGraph":
        """Load a graph from the GML file at ``gml_path``."""
        try:
            import networkx as nx
        except ImportError as e:
            raise ImportError(
                "Could not import networkx python package. "
                "Please install it with `pip install networkx`."
            ) from e
        graph = nx.read_gml(gml_path)
        return cls(graph)

    def add_triple(self, knowledge_triple: "KnowledgeTriple") -> None:
        """Add a triple to the graph."""
        # Creates nodes if they don't exist
        # Overwrites existing edges
        if not self._graph.has_node(knowledge_triple.subject):
            self._graph.add_node(knowledge_triple.subject)
        if not self._graph.has_node(knowledge_triple.object_):
            self._graph.add_node(knowledge_triple.object_)
        self._graph.add_edge(
            knowledge_triple.subject,
            knowledge_triple.object_,
            relation=knowledge_triple.predicate,
        )

    def delete_triple(self, knowledge_triple: "KnowledgeTriple") -> None:
        """Delete a triple from the graph."""
        if self._graph.has_edge(knowledge_triple.subject, knowledge_triple.object_):
            self._graph.remove_edge(knowledge_triple.subject, knowledge_triple.object_)

    def get_triples(self) -> List[Tuple[str, str, str]]:
        """Get all triples in the graph as ``(source, target, relation)``."""
        return [(u, v, d["relation"]) for u, v, d in self._graph.edges(data=True)]

    def get_entity_knowledge(self, entity: str, depth: int = 1) -> List[str]:
        """Get information about an entity.

        Returns one ``"<src> <relation> <sink>"`` string per edge visited by a
        depth-limited DFS from ``entity``; empty if the entity is unknown.
        """
        import networkx as nx

        # TODO: Have more information-specific retrieval methods
        if not self._graph.has_node(entity):
            return []

        results = []
        for src, sink in nx.dfs_edges(self._graph, entity, depth_limit=depth):
            relation = self._graph[src][sink]["relation"]
            results.append(f"{src} {relation} {sink}")
        return results

    def write_to_gml(self, path: str) -> None:
        """Write the graph to ``path`` in GML format."""
        import networkx as nx

        nx.write_gml(self._graph, path)

    def clear(self) -> None:
        """Clear the graph."""
        self._graph.clear()

    def clear_edges(self) -> None:
        """Clear the graph edges."""
        self._graph.clear_edges()

    def add_node(self, node: str) -> None:
        """Add node in the graph."""
        self._graph.add_node(node)

    def remove_node(self, node: str) -> None:
        """Remove node from the graph."""
        if self._graph.has_node(node):
            self._graph.remove_node(node)

    def has_node(self, node: str) -> bool:
        """Return if graph has the given node."""
        return self._graph.has_node(node)

    def remove_edge(self, source_node: str, destination_node: str) -> None:
        """Remove edge from the graph."""
        self._graph.remove_edge(source_node, destination_node)

    def has_edge(self, source_node: str, destination_node: str) -> bool:
        """Return if graph has an edge between the given nodes."""
        if self._graph.has_node(source_node) and self._graph.has_node(destination_node):
            return self._graph.has_edge(source_node, destination_node)
        else:
            return False

    def get_neighbors(self, node: str) -> List[str]:
        """Return the neighbor nodes of the given node."""
        # Materialized so the result matches the declared List[str] even when
        # the underlying graph hands back an iterator.
        return list(self._graph.neighbors(node))

    def get_number_of_nodes(self) -> int:
        """Get number of nodes in the graph."""
        return self._graph.number_of_nodes()

    def get_topological_sort(self) -> List[str]:
        """Get a list of entity names in the graph sorted by causal dependence."""
        import networkx as nx

        return list(nx.topological_sort(self._graph))

    def draw_graphviz(self, **kwargs: Any) -> None:
        """
        Provides better drawing

        Keyword Args:
            prog: graphviz layout program, default ``"dot"``.
            path: output file, default ``"graph.svg"``.

        Usage in a jupyter notebook:

            >>> from IPython.display import SVG
            >>> graph.draw_graphviz(prog="dot", path="web.svg")
            >>> SVG('web.svg')
        """
        from networkx.drawing.nx_agraph import to_agraph

        try:
            import pygraphviz  # noqa: F401

        except ImportError as e:
            if e.name == "_graphviz":
                # pygraphviz is installed but its native library is not, e.g.
                #   ImportError: libcgraph.so.6: cannot open shared object file
                raise ImportError(
                    "Could not import graphviz debian package. "
                    "Please install it with: "
                    "`sudo apt-get update` "
                    "`sudo apt-get install graphviz graphviz-dev`"
                ) from e
            else:
                raise ImportError(
                    "Could not import pygraphviz python package. "
                    "Please install it with: "
                    "`pip install pygraphviz`."
                ) from e

        graph = to_agraph(self._graph)  # --> pygraphviz.agraph.AGraph
        # pygraphviz.github.io/documentation/stable/tutorial.html#layout-and-drawing
        graph.layout(prog=kwargs.get("prog", "dot"))
        graph.draw(kwargs.get("path", "graph.svg"))
|
||||
@@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import rdflib
|
||||
|
||||
|
||||
class OntotextGraphDBGraph:
|
||||
"""Ontotext GraphDB https://graphdb.ontotext.com/ wrapper for graph operations.
|
||||
|
||||
*Security note*: Make sure that the database connection uses credentials
|
||||
that are narrowly-scoped to only include necessary permissions.
|
||||
Failure to do so may result in data corruption or loss, since the calling
|
||||
code may attempt commands that would result in deletion, mutation
|
||||
of data if appropriately prompted or reading sensitive data if such
|
||||
data is present in the database.
|
||||
The best way to guard against such negative outcomes is to (as appropriate)
|
||||
limit the permissions granted to the credentials used with this tool.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_endpoint: str,
|
||||
query_ontology: Optional[str] = None,
|
||||
local_file: Optional[str] = None,
|
||||
local_file_format: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set up the GraphDB wrapper
|
||||
|
||||
:param query_endpoint: SPARQL endpoint for queries, read access
|
||||
|
||||
If GraphDB is secured,
|
||||
set the environment variables 'GRAPHDB_USERNAME' and 'GRAPHDB_PASSWORD'.
|
||||
|
||||
:param query_ontology: a `CONSTRUCT` query that is executed
|
||||
on the SPARQL endpoint and returns the KG schema statements
|
||||
Example:
|
||||
'CONSTRUCT {?s ?p ?o} FROM <https://example.com/ontology/> WHERE {?s ?p ?o}'
|
||||
Currently, DESCRIBE queries like
|
||||
'PREFIX onto: <https://example.com/ontology/>
|
||||
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
|
||||
DESCRIBE ?term WHERE {
|
||||
?term rdfs:isDefinedBy onto:
|
||||
}'
|
||||
are not supported, because DESCRIBE returns
|
||||
the Symmetric Concise Bounded Description (SCBD),
|
||||
i.e. also the incoming class links.
|
||||
In case of large graphs with a million of instances, this is not efficient.
|
||||
Check https://github.com/eclipse-rdf4j/rdf4j/issues/4857
|
||||
|
||||
:param local_file: a local RDF ontology file.
|
||||
Supported RDF formats:
|
||||
Turtle, RDF/XML, JSON-LD, N-Triples, Notation-3, Trig, Trix, N-Quads.
|
||||
If the rdf format can't be determined from the file extension,
|
||||
pass explicitly the rdf format in `local_file_format` param.
|
||||
|
||||
:param local_file_format: Used if the rdf format can't be determined
|
||||
from the local file extension.
|
||||
One of "json-ld", "xml", "n3", "turtle", "nt", "trig", "nquads", "trix"
|
||||
|
||||
Either `query_ontology` or `local_file` should be passed.
|
||||
"""
|
||||
|
||||
if query_ontology and local_file:
|
||||
raise ValueError("Both file and query provided. Only one is allowed.")
|
||||
|
||||
if not query_ontology and not local_file:
|
||||
raise ValueError("Neither file nor query provided. One is required.")
|
||||
|
||||
try:
|
||||
import rdflib
|
||||
from rdflib.plugins.stores import sparqlstore
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import rdflib python package. "
|
||||
"Please install it with `pip install rdflib`."
|
||||
)
|
||||
|
||||
auth = self._get_auth()
|
||||
store = sparqlstore.SPARQLStore(auth=auth)
|
||||
store.open(query_endpoint)
|
||||
|
||||
self.graph = rdflib.Graph(store, identifier=None, bind_namespaces="none")
|
||||
self._check_connectivity()
|
||||
|
||||
ontology_schema_graph: "rdflib.Graph"
|
||||
if local_file:
|
||||
ontology_schema_graph = self._load_ontology_schema_from_file(
|
||||
local_file,
|
||||
local_file_format,
|
||||
)
|
||||
else:
|
||||
self._validate_user_query(query_ontology) # type: ignore[arg-type]
|
||||
ontology_schema_graph = self._load_ontology_schema_with_query(
|
||||
query_ontology # type: ignore[arg-type]
|
||||
)
|
||||
self.schema = ontology_schema_graph.serialize(format="turtle")
|
||||
|
||||
@staticmethod
|
||||
def _get_auth() -> Union[tuple, None]:
|
||||
"""
|
||||
Returns the basic authentication configuration
|
||||
"""
|
||||
username = os.environ.get("GRAPHDB_USERNAME", None)
|
||||
password = os.environ.get("GRAPHDB_PASSWORD", None)
|
||||
|
||||
if username:
|
||||
if not password:
|
||||
raise ValueError(
|
||||
"Environment variable 'GRAPHDB_USERNAME' is set, "
|
||||
"but 'GRAPHDB_PASSWORD' is not set."
|
||||
)
|
||||
else:
|
||||
return username, password
|
||||
return None
|
||||
|
||||
def _check_connectivity(self) -> None:
|
||||
"""
|
||||
Executes a simple `ASK` query to check connectivity
|
||||
"""
|
||||
try:
|
||||
self.graph.query("ASK { ?s ?p ?o }")
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"Could not query the provided endpoint. "
|
||||
"Please, check, if the value of the provided "
|
||||
"query_endpoint points to the right repository. "
|
||||
"If GraphDB is secured, please, "
|
||||
"make sure that the environment variables "
|
||||
"'GRAPHDB_USERNAME' and 'GRAPHDB_PASSWORD' are set."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _load_ontology_schema_from_file(
|
||||
local_file: str, local_file_format: Optional[str] = None
|
||||
) -> "rdflib.ConjunctiveGraph":
|
||||
"""
|
||||
Parse the ontology schema statements from the provided file
|
||||
"""
|
||||
import rdflib
|
||||
|
||||
if not os.path.exists(local_file):
|
||||
raise FileNotFoundError(f"File {local_file} does not exist.")
|
||||
if not os.access(local_file, os.R_OK):
|
||||
raise PermissionError(f"Read permission for {local_file} is restricted")
|
||||
graph = rdflib.ConjunctiveGraph()
|
||||
try:
|
||||
graph.parse(local_file, format=local_file_format)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid file format for {local_file} : ", e)
|
||||
return graph
|
||||
|
||||
@staticmethod
|
||||
def _validate_user_query(query_ontology: str) -> None:
|
||||
"""
|
||||
Validate the query is a valid SPARQL CONSTRUCT query
|
||||
"""
|
||||
from pyparsing import ParseException
|
||||
from rdflib.plugins.sparql import prepareQuery
|
||||
|
||||
if not isinstance(query_ontology, str):
|
||||
raise TypeError("Ontology query must be provided as string.")
|
||||
try:
|
||||
parsed_query = prepareQuery(query_ontology)
|
||||
except ParseException as e:
|
||||
raise ValueError("Ontology query is not a valid SPARQL query.", e)
|
||||
|
||||
if parsed_query.algebra.name != "ConstructQuery":
|
||||
raise ValueError(
|
||||
"Invalid query type. Only CONSTRUCT queries are supported."
|
||||
)
|
||||
|
||||
def _load_ontology_schema_with_query(self, query: str) -> "rdflib.Graph":
    """
    Run ``query`` against ``self.graph`` and return the resulting graph.

    :param query: query text used to collect the ontology schema statements
    :return: the ``graph`` attribute of the query result
    :raises ValueError: if rdflib reports a ``ParserError`` for the query,
        or if the result carries no (or an empty) graph
    """
    from rdflib.exceptions import ParserError

    try:
        outcome = self.graph.query(query)
    except ParserError as parser_error:
        raise ValueError(f"Generated SPARQL statement is invalid\n{parser_error}")

    schema_graph = outcome.graph
    if schema_graph:
        return schema_graph
    raise ValueError("Missing graph in results.")
|
||||
|
||||
@property
def get_schema(self) -> str:
    """
    Return the schema text stored on this instance (``self.schema``).

    Per the original documentation the schema is in Turtle format.
    """
    schema_text = self.schema
    return schema_text
|
||||
|
||||
def query(
    self,
    query: str,
) -> List[rdflib.query.ResultRow]:
    """
    Run ``query`` against ``self.graph`` and collect its result rows.

    :param query: the query text passed to ``self.graph.query``
    :return: the items of the result that are ``ResultRow`` instances, in
        iteration order; any other items are dropped
    """
    from rdflib.query import ResultRow

    rows = []
    for item in self.graph.query(query):
        if isinstance(item, ResultRow):
            rows.append(item)
    return rows
|
||||
307
venv/Lib/site-packages/langchain_community/graphs/rdf_graph.py
Normal file
307
venv/Lib/site-packages/langchain_community/graphs/rdf_graph.py
Normal file
@@ -0,0 +1,307 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import rdflib
|
||||
|
||||
# Reusable SPARQL PREFIX declarations, keyed by namespace abbreviation.
# Each value already ends with a newline so it can be prepended to a query.
prefixes = {
    "owl": "PREFIX owl: <http://www.w3.org/2002/07/owl#>\n",
    "rdf": "PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>\n",
    "rdfs": "PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>\n",
    "xsd": "PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>\n",
}

# RDF: every class that types at least one instance, with optional comment.
cls_query_rdf = (
    prefixes["rdfs"]
    + "SELECT DISTINCT ?cls ?com\n"
    + "WHERE { \n"
    + " ?instance a ?cls . \n"
    + " OPTIONAL { ?cls rdfs:comment ?com } \n"
    + "}"
)

# RDFS: as above, but also follows rdfs:subClassOf upwards from the type.
cls_query_rdfs = (
    prefixes["rdfs"]
    + "SELECT DISTINCT ?cls ?com\n"
    + "WHERE { \n"
    + " ?instance a/rdfs:subClassOf* ?cls . \n"
    + " OPTIONAL { ?cls rdfs:comment ?com } \n"
    + "}"
)

# OWL: as the RDFS variant, restricted to classes identified by an IRI.
cls_query_owl = (
    prefixes["rdfs"]
    + "SELECT DISTINCT ?cls ?com\n"
    + "WHERE { \n"
    + " ?instance a/rdfs:subClassOf* ?cls . \n"
    + " FILTER (isIRI(?cls)) . \n"
    + " OPTIONAL { ?cls rdfs:comment ?com } \n"
    + "}"
)

# RDF: every predicate used in any triple, with optional comment.
rel_query_rdf = (
    prefixes["rdfs"]
    + "SELECT DISTINCT ?rel ?com\n"
    + "WHERE { \n"
    + " ?subj ?rel ?obj . \n"
    + " OPTIONAL { ?rel rdfs:comment ?com } \n"
    + "}"
)

# RDFS: resources typed (directly or via rdfs:subPropertyOf) as rdf:Property.
rel_query_rdfs = (
    prefixes["rdf"]
    + prefixes["rdfs"]
    + "SELECT DISTINCT ?rel ?com\n"
    + "WHERE { \n"
    + " ?rel a/rdfs:subPropertyOf* rdf:Property . \n"
    + " OPTIONAL { ?rel rdfs:comment ?com } \n"
    + "}"
)

# OWL: object properties (typed directly or via rdfs:subPropertyOf).
op_query_owl = (
    prefixes["rdfs"]
    + prefixes["owl"]
    + "SELECT DISTINCT ?op ?com\n"
    + "WHERE { \n"
    + " ?op a/rdfs:subPropertyOf* owl:ObjectProperty . \n"
    + " OPTIONAL { ?op rdfs:comment ?com } \n"
    + "}"
)

# OWL: datatype properties (typed directly or via rdfs:subPropertyOf).
dp_query_owl = (
    prefixes["rdfs"]
    + prefixes["owl"]
    + "SELECT DISTINCT ?dp ?com\n"
    + "WHERE { \n"
    + " ?dp a/rdfs:subPropertyOf* owl:DatatypeProperty . \n"
    + " OPTIONAL { ?dp rdfs:comment ?com } \n"
    + "}"
)
|
||||
|
||||
|
||||
class RdfGraph:
    """RDFlib wrapper for graph operations.

    Modes:
    * local: Local file - can be queried and changed
    * online: Online file - can only be queried, changes can be stored locally
    * store: Triple store - can be queried and changed if update_endpoint available
    Together with a source file, the serialization should be specified.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        source_file: Optional[str] = None,
        serialization: Optional[str] = "ttl",
        query_endpoint: Optional[str] = None,
        update_endpoint: Optional[str] = None,
        standard: Optional[str] = "rdf",
        local_copy: Optional[str] = None,
        graph_kwargs: Optional[Dict] = None,
        store_kwargs: Optional[Dict] = None,
    ) -> None:
        """
        Set up the RDFlib graph

        :param source_file: either a path for a local file or a URL
        :param serialization: serialization of the input
        :param query_endpoint: SPARQL endpoint for queries, read access
        :param update_endpoint: SPARQL endpoint for UPDATE queries, write access
        :param standard: RDF, RDFS, or OWL
        :param local_copy: new local copy for storing changes
        :param graph_kwargs: Additional rdflib.Graph specific kwargs
        that will be used to initialize it,
        if query_endpoint is provided.
        :param store_kwargs: Additional sparqlstore.SPARQLStore specific kwargs
        that will be used to initialize it,
        if query_endpoint is provided.
        :raises ImportError: if rdflib cannot be imported
        :raises ValueError: if ``standard`` is unsupported, or if the
            arguments do not select exactly one of "file" or "triple store"
        :raises AssertionError: if the loaded graph contains no triples
        """
        self.source_file = source_file
        self.serialization = serialization
        self.query_endpoint = query_endpoint
        self.update_endpoint = update_endpoint
        self.standard = standard
        self.local_copy = local_copy

        try:
            import rdflib
            from rdflib.plugins.stores import sparqlstore
        except ImportError as e:
            raise ImportError(
                "Could not import rdflib python package. "
                "Please install it with `pip install rdflib`."
            ) from e
        if self.standard not in (supported_standards := ("rdf", "rdfs", "owl")):
            raise ValueError(
                f"Invalid standard. Supported standards are: {supported_standards}."
            )

        # Exactly one data source must be chosen: a file, or a triple store
        # (endpoints). Neither, or a file combined with endpoints, is ambiguous.
        if (not source_file and not query_endpoint) or (
            source_file and (query_endpoint or update_endpoint)
        ):
            raise ValueError(
                "Could not unambiguously initialize the graph wrapper. "
                "Specify either a file (local or online) via the source_file "
                "or a triple store via the endpoints."
            )

        if source_file:
            if source_file.startswith("http"):
                self.mode = "online"
            else:
                self.mode = "local"
                # A local file is its own save target unless one was given.
                if self.local_copy is None:
                    self.local_copy = self.source_file
            self.graph = rdflib.Graph()
            self.graph.parse(source_file, format=self.serialization)

        if query_endpoint:
            store_kwargs = store_kwargs or {}
            self.mode = "store"
            if not update_endpoint:
                self._store = sparqlstore.SPARQLStore(**store_kwargs)
                self._store.open(query_endpoint)
            else:
                self._store = sparqlstore.SPARQLUpdateStore(**store_kwargs)
                self._store.open((query_endpoint, update_endpoint))
            graph_kwargs = graph_kwargs or {}
            self.graph = rdflib.Graph(self._store, **graph_kwargs)

        # Verify that the graph was loaded
        if not len(self.graph):
            raise AssertionError("The graph is empty.")

        # Set schema
        self.schema = ""
        self.load_schema()

    @property
    def get_schema(self) -> str:
        """
        Returns the schema of the graph database.
        """
        return self.schema

    def query(
        self,
        query: str,
    ) -> List[rdflib.query.ResultRow]:
        """
        Query the graph.

        :param query: query text passed to ``self.graph.query``
        :return: the result items that are ``ResultRow`` instances
        :raises ValueError: if rdflib raises ``ParserError`` for the query
        """
        from rdflib.exceptions import ParserError
        from rdflib.query import ResultRow

        # NOTE(review): only rdflib's ParserError is translated to ValueError;
        # any other exception type raised while querying propagates unchanged.
        try:
            res = self.graph.query(query)
        except ParserError as e:
            raise ValueError(f"Generated SPARQL statement is invalid\n{e}") from e
        return [r for r in res if isinstance(r, ResultRow)]

    def update(
        self,
        query: str,
    ) -> None:
        """
        Update the graph.

        The update is applied through ``self.graph.update``. If a
        ``local_copy`` path is set, the graph is then serialized to it, using
        the text after the last ``.`` in the path as the format name.

        :param query: update statement passed to ``self.graph.update``
        :raises ValueError: if rdflib raises ``ParserError`` for the
            statement, or if the graph is file-based and no ``local_copy``
            target is available to save the change
        """
        from rdflib.exceptions import ParserError

        try:
            self.graph.update(query)
        except ParserError as e:
            raise ValueError(f"Generated SPARQL statement is invalid\n{e}") from e
        if self.local_copy:
            self.graph.serialize(
                destination=self.local_copy, format=self.local_copy.split(".")[-1]
            )
        elif self.mode != "store":
            # File-based graphs live only in memory; without a target file
            # the change cannot be saved.
            raise ValueError("No target file specified for saving the updated file.")
        # In "store" mode the graph is backed by the SPARQL store, so the
        # update has already been handed to it; a missing local copy is not
        # an error there.

    @staticmethod
    def _get_local_name(iri: str) -> str:
        """
        Return the part of ``iri`` after its last ``#``, or after its last
        ``/`` when it contains no ``#``.

        :raises ValueError: if ``iri`` contains neither ``#`` nor ``/``
        """
        if "#" in iri:
            local_name = iri.split("#")[-1]
        elif "/" in iri:
            local_name = iri.split("/")[-1]
        else:
            raise ValueError(f"Unexpected IRI '{iri}', contains neither '#' nor '/'.")
        return local_name

    def _res_to_str(self, res: rdflib.query.ResultRow, var: str) -> str:
        """
        Format one result row as ``<IRI> (local name, comment)``.

        :param res: row providing ``res[var]`` (the IRI) and ``res["com"]``
        :param var: name of the variable holding the IRI
        """
        return (
            "<"
            + str(res[var])
            + "> ("
            + self._get_local_name(res[var])
            + ", "
            + str(res["com"])
            + ")"
        )

    def load_schema(self) -> None:
        """
        Load the graph schema information.

        Runs the class/property queries matching ``self.standard`` and stores
        a textual description of the results in ``self.schema``.

        :raises ValueError: if ``self.standard`` is not "rdf", "rdfs" or "owl"
        """

        def _rdf_s_schema(
            classes: List[rdflib.query.ResultRow],
            relationships: List[rdflib.query.ResultRow],
        ) -> str:
            # Shared layout for the RDF and RDFS standards.
            return (
                f"In the following, each IRI is followed by the local name and "
                f"optionally its description in parentheses. \n"
                f"The RDF graph supports the following node types:\n"
                f"{', '.join([self._res_to_str(r, 'cls') for r in classes])}\n"
                f"The RDF graph supports the following relationships:\n"
                f"{', '.join([self._res_to_str(r, 'rel') for r in relationships])}\n"
            )

        if self.standard == "rdf":
            clss = self.query(cls_query_rdf)
            rels = self.query(rel_query_rdf)
            self.schema = _rdf_s_schema(clss, rels)
        elif self.standard == "rdfs":
            clss = self.query(cls_query_rdfs)
            rels = self.query(rel_query_rdfs)
            self.schema = _rdf_s_schema(clss, rels)
        elif self.standard == "owl":
            clss = self.query(cls_query_owl)
            ops = self.query(op_query_owl)
            dps = self.query(dp_query_owl)
            self.schema = (
                f"In the following, each IRI is followed by the local name and "
                f"optionally its description in parentheses. \n"
                f"The OWL graph supports the following node types:\n"
                f"{', '.join([self._res_to_str(r, 'cls') for r in clss])}\n"
                f"The OWL graph supports the following object properties, "
                f"i.e., relationships between objects:\n"
                f"{', '.join([self._res_to_str(r, 'op') for r in ops])}\n"
                f"The OWL graph supports the following data properties, "
                f"i.e., relationships between objects and literals:\n"
                f"{', '.join([self._res_to_str(r, 'dp') for r in dps])}\n"
            )
        else:
            raise ValueError(f"Mode '{self.standard}' is currently not supported.")
|
||||
@@ -0,0 +1,100 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_community.graphs.graph_store import GraphStore
|
||||
|
||||
|
||||
class TigerGraph(GraphStore):
    """TigerGraph wrapper for graph operations.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, conn: Any) -> None:
        """Create a new TigerGraph graph wrapper instance.

        :param conn: connection object; validated by ``set_connection``
        """
        # set_connection() validates the connection and already loads the
        # schema, so a second forced schema fetch here would be redundant.
        self.set_connection(conn)

    @property
    def conn(self) -> Any:
        """Return the connection stored by ``set_connection``."""
        return self._conn

    @property
    def schema(self) -> Dict[str, Any]:
        """Return the currently cached schema."""
        return self._schema

    def get_schema(self) -> str:  # type: ignore[override]
        """Return the cached schema as a string, generating it first if the
        cached value is empty."""
        if not self._schema:
            self.set_schema()
        return str(self._schema)

    def set_connection(self, conn: Any) -> None:
        """Validate and store the connection, then load the schema from it.

        :param conn: must be a ``pyTigerGraph.TigerGraphConnection`` whose
            ``ai.nlqs_host`` is set
        :raises ImportError: if pyTigerGraph cannot be imported
        :raises TypeError: if ``conn`` is not a ``TigerGraphConnection``
        :raises ConnectionError: if ``conn.ai.nlqs_host`` is ``None``
        """
        try:
            from pyTigerGraph import TigerGraphConnection
        except ImportError as e:
            raise ImportError(
                "Could not import pyTigerGraph python package. "
                "Please install it with `pip install pyTigerGraph`."
            ) from e

        if not isinstance(conn, TigerGraphConnection):
            msg = "**conn** parameter must inherit from TigerGraphConnection"
            raise TypeError(msg)

        if conn.ai.nlqs_host is None:
            msg = """**conn** parameter does not have nlqs_host parameter defined.
Define hostname of NLQS service."""
            raise ConnectionError(msg)

        self._conn: TigerGraphConnection = conn
        self.set_schema()

    def set_schema(self, schema: Optional[Dict[str, Any]] = None) -> None:
        """
        Set the schema of the TigerGraph Database.
        Auto-generates Schema if **schema** is None.
        """
        self._schema = self.generate_schema() if schema is None else schema

    def generate_schema(
        self,
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Fetch the schema from the connection and return it.

        Calls ``getSchema(force=True)`` on the stored connection; the cached
        schema on this wrapper is not modified.
        """
        return self._conn.getSchema(force=True)

    def refresh_schema(self) -> None:
        """Regenerate the schema from the connection and cache it."""
        # Store the fresh schema; merely generating it would leave the
        # cached value stale.
        self.set_schema()

    def query(self, query: str) -> Dict[str, Any]:  # type: ignore[override]
        """Query the TigerGraph database."""
        answer = self._conn.ai.query(query)
        return answer

    def register_query(
        self,
        function_header: str,
        description: str,
        docstring: str,
        param_types: Optional[dict] = None,
    ) -> List[str]:
        """
        Wrapper function to register a custom GSQL query to the TigerGraph NLQS.

        :param function_header: forwarded to ``registerCustomQuery``
        :param description: forwarded to ``registerCustomQuery``
        :param docstring: forwarded to ``registerCustomQuery``
        :param param_types: forwarded to ``registerCustomQuery``; an empty
            dict is sent when omitted
        :return: whatever ``registerCustomQuery`` returns
        """
        # A fresh dict per call avoids sharing one mutable default object.
        effective_param_types = {} if param_types is None else param_types
        return self._conn.ai.registerCustomQuery(
            function_header, description, docstring, effective_param_types
        )
|
||||
Reference in New Issue
Block a user