initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,95 @@
"""**Graphs** provide a natural language interface to graph databases."""
import importlib
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langchain_community.graphs.arangodb_graph import (
ArangoGraph,
)
from langchain_community.graphs.falkordb_graph import (
FalkorDBGraph,
)
from langchain_community.graphs.gremlin_graph import (
GremlinGraph,
)
from langchain_community.graphs.hugegraph import (
HugeGraph,
)
from langchain_community.graphs.kuzu_graph import (
KuzuGraph,
)
from langchain_community.graphs.memgraph_graph import (
MemgraphGraph,
)
from langchain_community.graphs.nebula_graph import (
NebulaGraph,
)
from langchain_community.graphs.neo4j_graph import (
Neo4jGraph,
)
from langchain_community.graphs.neptune_graph import (
BaseNeptuneGraph,
NeptuneAnalyticsGraph,
NeptuneGraph,
)
from langchain_community.graphs.neptune_rdf_graph import (
NeptuneRdfGraph,
)
from langchain_community.graphs.networkx_graph import (
NetworkxEntityGraph,
)
from langchain_community.graphs.ontotext_graphdb_graph import (
OntotextGraphDBGraph,
)
from langchain_community.graphs.rdf_graph import (
RdfGraph,
)
from langchain_community.graphs.tigergraph_graph import (
TigerGraph,
)
# Public names exported by this package. Every entry has a matching key in
# ``_module_lookup`` below, which the module-level ``__getattr__`` uses to
# import the defining module on first access.
__all__ = [
    "ArangoGraph",
    "FalkorDBGraph",
    "GremlinGraph",
    "HugeGraph",
    "KuzuGraph",
    "BaseNeptuneGraph",
    "MemgraphGraph",
    "NebulaGraph",
    "Neo4jGraph",
    "NeptuneGraph",
    "NeptuneRdfGraph",
    "NeptuneAnalyticsGraph",
    "NetworkxEntityGraph",
    "OntotextGraphDBGraph",
    "RdfGraph",
    "TigerGraph",
]
# Maps each exported class name to the dotted path of the module defining it.
_module_lookup = {
    "ArangoGraph": "langchain_community.graphs.arangodb_graph",
    "FalkorDBGraph": "langchain_community.graphs.falkordb_graph",
    "GremlinGraph": "langchain_community.graphs.gremlin_graph",
    "HugeGraph": "langchain_community.graphs.hugegraph",
    "KuzuGraph": "langchain_community.graphs.kuzu_graph",
    "MemgraphGraph": "langchain_community.graphs.memgraph_graph",
    "NebulaGraph": "langchain_community.graphs.nebula_graph",
    "Neo4jGraph": "langchain_community.graphs.neo4j_graph",
    "BaseNeptuneGraph": "langchain_community.graphs.neptune_graph",
    "NeptuneAnalyticsGraph": "langchain_community.graphs.neptune_graph",
    "NeptuneGraph": "langchain_community.graphs.neptune_graph",
    "NeptuneRdfGraph": "langchain_community.graphs.neptune_rdf_graph",
    "NetworkxEntityGraph": "langchain_community.graphs.networkx_graph",
    "OntotextGraphDBGraph": "langchain_community.graphs.ontotext_graphdb_graph",
    "RdfGraph": "langchain_community.graphs.rdf_graph",
    "TigerGraph": "langchain_community.graphs.tigergraph_graph",
}
def __getattr__(name: str) -> Any:
    """Resolve a lazily exported attribute of this package.

    Args:
        name: the attribute being looked up on the module.

    Returns:
        The object called ``name`` taken from the module registered for it
        in ``_module_lookup``.

    Raises:
        AttributeError: if ``name`` is not a key of ``_module_lookup``.
    """
    module_path = _module_lookup.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    # Import only now, so optional dependencies load on first use.
    return getattr(importlib.import_module(module_path), name)

View File

@@ -0,0 +1,765 @@
from __future__ import annotations
import json
import re
from hashlib import md5
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Pattern, Tuple, Union
from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore
if TYPE_CHECKING:
import psycopg2.extras
class AGEQueryException(Exception):
    """Exception for the AGE queries.

    Args:
        exception: either a plain message string, or a dict holding a
            ``"message"`` entry and a ``"details"`` (or ``"detail"``) entry.
            Missing entries default to ``"unknown"``.
    """

    def __init__(self, exception: Union[str, Dict]) -> None:
        if isinstance(exception, dict):
            self.message = exception.get("message", "unknown")
            # Callers in this module pass the singular key "detail"; accept
            # both spellings so the database error text is not lost.
            if "details" in exception:
                self.details = exception["details"]
            else:
                self.details = exception.get("detail", "unknown")
        else:
            self.message = exception
            self.details = "unknown"

    def get_message(self) -> str:
        """Return the short error message."""
        return self.message

    def get_details(self) -> Any:
        """Return the error details, or ``"unknown"`` when none were given."""
        return self.details
class AGEGraph(GraphStore):
    """
    Apache AGE wrapper for graph operations.

    Args:
        graph_name (str): the name of the graph to connect to or create
        conf (Dict[str, Any]): the pgsql connection config passed directly
            to psycopg2.connect
        create (bool): if True and graph doesn't exist, attempt to create it

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    # python type mapping for providing readable types to LLM;
    # keys are ``type(value).__name__`` of values decoded with json.loads
    types = {
        "str": "STRING",
        "float": "DOUBLE",
        "int": "INTEGER",
        "list": "LIST",
        "dict": "MAP",
        "bool": "BOOLEAN",
    }

    # precompiled regex for checking chars in graph labels: matches every run
    # of characters that are not ASCII letters or digits
    label_regex: Pattern = re.compile("[^0-9a-zA-Z]+")
def __init__(
    self, graph_name: str, conf: Dict[str, Any], create: bool = True
) -> None:
    """Create a new AGEGraph instance.

    Connects with ``psycopg2.connect(**conf)``, looks the graph up in
    ``ag_catalog.ag_graph`` (creating it when allowed), stores its id in
    ``self.graphid`` and refreshes the schema.

    Args:
        graph_name (str): the name of the graph to connect to or create
        conf (Dict[str, Any]): keyword arguments for psycopg2.connect
        create (bool): if True and graph doesn't exist, attempt to create it

    Raises:
        ImportError: if psycopg2 is not installed
        AGEQueryException: if the graph could not be created
        Exception: if the graph does not exist and ``create`` is False
    """
    self.graph_name = graph_name

    # check that psycopg2 is installed
    try:
        import psycopg2
    except ImportError as e:
        raise ImportError(
            "Could not import psycopg2 python package. "
            "Please install it with `pip install psycopg2`."
        ) from e

    self.connection = psycopg2.connect(**conf)

    # The graph name is bound as a query parameter instead of being
    # formatted into the SQL text, so quotes in the name cannot alter
    # the statement.
    graph_id_query = "SELECT graphid FROM ag_catalog.ag_graph WHERE name = %s"
    query_args = (graph_name,)

    with self._get_cursor() as curs:
        # check if graph with name graph_name exists
        curs.execute(graph_id_query, query_args)
        data = curs.fetchone()

        # if graph doesn't exist and create is True, create it
        if data is None:
            if not create:
                raise Exception(
                    (
                        'Graph "{}" does not exist in the database '
                        + 'and "create" is set to False'
                    ).format(graph_name)
                )
            try:
                curs.execute("SELECT ag_catalog.create_graph(%s);", query_args)
                self.connection.commit()
            except psycopg2.Error as e:
                raise AGEQueryException(
                    {
                        "message": "Could not create the graph",
                        # "details" is the key AGEQueryException reads
                        "details": str(e),
                    }
                ) from e

            # read back the id of the freshly created graph
            curs.execute(graph_id_query, query_args)
            data = curs.fetchone()

        # store graph id and refresh the schema
        self.graphid = data.graphid
        self.refresh_schema()
def _get_cursor(self) -> psycopg2.extras.NamedTupleCursor:
    """
    Open a cursor on ``self.connection`` that is ready for AGE queries.

    The cursor yields named-tuple rows; before it is returned the AGE
    extension is loaded and the search path is set.

    Raises:
        ImportError: if psycopg2 cannot be imported
    """
    try:
        import psycopg2.extras
    except ImportError as e:
        raise ImportError(
            "Unable to import psycopg2, please install with "
            "`pip install -U psycopg2`."
        ) from e

    cursor = self.connection.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
    # session setup statements, executed in this order
    setup_statements = (
        """LOAD 'age';""",
        """SET search_path = ag_catalog, "$user", public;""",
    )
    for statement in setup_statements:
        cursor.execute(statement)
    return cursor
def _get_labels(self) -> Tuple[List[str], List[str]]:
"""
Get all labels of a graph (for both edges and vertices)
by querying the graph metadata table directly
Returns
Tuple[List[str]]: 2 lists, the first containing vertex
labels and the second containing edge labels
"""
e_labels_records = self.query(
"""MATCH ()-[e]-() RETURN collect(distinct label(e)) as labels"""
)
e_labels = e_labels_records[0]["labels"] if e_labels_records else []
n_labels_records = self.query(
"""MATCH (n) RETURN collect(distinct label(n)) as labels"""
)
n_labels = n_labels_records[0]["labels"] if n_labels_records else []
return n_labels, e_labels
def _get_triples(self, e_labels: List[str]) -> List[Dict[str, str]]:
    """
    Get a set of distinct relationship types (as a list of dicts) in the graph
    to be used as context by an llm.

    Args:
        e_labels (List[str]): a list of edge labels to filter for

    Returns:
        List[Dict[str, str]]: relationships as a list of dicts in the format
            "{'start':<from_label>, 'type':<edge_label>, 'end':<to_label>}"

    Raises:
        ImportError: if psycopg2 cannot be imported
        AGEQueryException: if one of the per-label queries fails
    """
    try:
        import psycopg2
    except ImportError as e:
        raise ImportError(
            "Unable to import psycopg2, please install with "
            "`pip install -U psycopg2`."
        ) from e

    # age query to get distinct relationship types for one edge label
    triple_query = """
SELECT * FROM ag_catalog.cypher('{graph_name}', $$
MATCH (a)-[e:`{e_label}`]->(b)
WITH a,e,b LIMIT 3000
RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
LIMIT 10
$$) AS (f agtype, edge agtype, t agtype);
"""

    collected: List[Dict[str, str]] = []

    # one query per requested edge label, all on the same cursor
    with self._get_cursor() as cursor:
        for edge_label in e_labels:
            statement = triple_query.format(
                graph_name=self.graph_name, e_label=edge_label
            )
            try:
                cursor.execute(statement)
                for row in cursor.fetchall():
                    # json.loads turns the returned agtype strings into
                    # python values; only the first label of each end is kept
                    start_label = json.loads(row.f)[0]
                    edge_type = json.loads(row.edge)
                    end_label = json.loads(row.t)[0]
                    collected.append(
                        {"start": start_label, "type": edge_type, "end": end_label}
                    )
            except psycopg2.Error as e:
                raise AGEQueryException(
                    {
                        "message": "Error fetching triples",
                        "detail": str(e),
                    }
                )

    return collected
def _get_triples_str(self, e_labels: List[str]) -> List[str]:
"""
Get a set of distinct relationship types (as a list of strings) in the graph
to be used as context by an llm.
Args:
e_labels (List[str]): a list of edge labels to filter for
Returns:
List[str]: relationships as a list of strings in the format
"(:`<from_label>`)-[:`<edge_label>`]->(:`<to_label>`)"
"""
triples = self._get_triples(e_labels)
return self._format_triples(triples)
@staticmethod
def _format_triples(triples: List[Dict[str, str]]) -> List[str]:
"""
Convert a list of relationships from dictionaries to formatted strings
to be better readable by an llm
Args:
triples (List[Dict[str,str]]): a list relationships in the form
{'start':<from_label>, 'type':<edge_label>, 'end':<from_label>}
Returns:
List[str]: a list of relationships in the form
"(:`<from_label>`)-[:`<edge_label>`]->(:`<to_label>`)"
"""
triple_template = "(:`{start}`)-[:`{type}`]->(:`{end}`)"
triple_schema = [triple_template.format(**triple) for triple in triples]
return triple_schema
def _get_node_properties(self, n_labels: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch a list of available node properties by node label to be used
    as context for an llm

    Up to 100 nodes are sampled per label. Property value types that are
    not listed in ``self.types`` (for example a JSON null) are reported
    as "UNKNOWN" instead of raising a KeyError.

    Args:
        n_labels (List[str]): a list of node labels to filter for

    Returns:
        List[Dict[str, Any]]: a list of node labels and
            their corresponding properties in the form
            "{
                'labels': <node_label>,
                'properties': [
                    {
                        'property': <property_name>,
                        'type': <property_type>
                    },...
                    ]
            }"

    Raises:
        ImportError: if psycopg2 cannot be imported
        AGEQueryException: if one of the per-label queries fails
    """
    try:
        import psycopg2
    except ImportError as e:
        raise ImportError(
            "Unable to import psycopg2, please install with "
            "`pip install -U psycopg2`."
        ) from e

    # cypher query to fetch properties of a given label
    node_properties_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH (a:`{n_label}`)
            RETURN properties(a) AS props
            LIMIT 100
        $$) AS (props agtype);
    """

    node_properties = []
    with self._get_cursor() as curs:
        for label in n_labels:
            q = node_properties_query.format(
                graph_name=self.graph_name, n_label=label
            )
            try:
                curs.execute(q)
            except psycopg2.Error as e:
                raise AGEQueryException(
                    {
                        "message": "Error fetching node properties",
                        # "details" is the key AGEQueryException reads
                        "details": str(e),
                    }
                ) from e

            # build a set of distinct (property, readable type) pairs
            distinct_props = set()
            for row in curs.fetchall():
                # use json.loads to convert to python
                # primitive and get readable type
                for key, value in json.loads(row.props).items():
                    readable = self.types.get(type(value).__name__, "UNKNOWN")
                    distinct_props.add((key, readable))

            node_properties.append(
                {
                    # sorted so the generated schema is stable across runs
                    "properties": [
                        {"property": key, "type": readable}
                        for key, readable in sorted(distinct_props)
                    ],
                    "labels": label,
                }
            )
    return node_properties
def _get_edge_properties(self, e_labels: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch a list of available edge properties by edge label to be used
    as context for an llm

    Up to 100 edges are sampled per label. Property value types that are
    not listed in ``self.types`` (for example a JSON null) are reported
    as "UNKNOWN" instead of raising a KeyError.

    Args:
        e_labels (List[str]): a list of edge labels to filter for

    Returns:
        List[Dict[str, Any]]: a list of edge labels
            and their corresponding properties in the form
            "{
                'type': <edge_label>,
                'properties': [
                    {
                        'property': <property_name>,
                        'type': <property_type>
                    },...
                    ]
            }"

    Raises:
        ImportError: if psycopg2 cannot be imported
        AGEQueryException: if one of the per-label queries fails
    """
    try:
        import psycopg2
    except ImportError as e:
        raise ImportError(
            "Unable to import psycopg2, please install with "
            "`pip install -U psycopg2`."
        ) from e

    # cypher query to fetch properties of a given label
    edge_properties_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH ()-[e:`{e_label}`]->()
            RETURN properties(e) AS props
            LIMIT 100
        $$) AS (props agtype);
    """

    edge_properties = []
    with self._get_cursor() as curs:
        for label in e_labels:
            q = edge_properties_query.format(
                graph_name=self.graph_name, e_label=label
            )
            try:
                curs.execute(q)
            except psycopg2.Error as e:
                raise AGEQueryException(
                    {
                        "message": "Error fetching edge properties",
                        # "details" is the key AGEQueryException reads
                        "details": str(e),
                    }
                ) from e

            # build a set of distinct (property, readable type) pairs
            distinct_props = set()
            for row in curs.fetchall():
                # use json.loads to convert to python
                # primitive and get readable type
                for key, value in json.loads(row.props).items():
                    readable = self.types.get(type(value).__name__, "UNKNOWN")
                    distinct_props.add((key, readable))

            edge_properties.append(
                {
                    # sorted so the generated schema is stable across runs
                    "properties": [
                        {"property": key, "type": readable}
                        for key, readable in sorted(distinct_props)
                    ],
                    "type": label,
                }
            )
    return edge_properties
def refresh_schema(self) -> None:
"""
Refresh the graph schema information by updating the available
labels, relationships, and properties
"""
# fetch graph schema information
n_labels, e_labels = self._get_labels()
triple_schema = self._get_triples(e_labels)
node_properties = self._get_node_properties(n_labels)
edge_properties = self._get_edge_properties(e_labels)
# update the formatted string representation
self.schema = f"""
Node properties are the following:
{node_properties}
Relationship properties are the following:
{edge_properties}
The relationships are the following:
{self._format_triples(triple_schema)}
"""
# update the dictionary representation
self.structured_schema = {
"node_props": {el["labels"]: el["properties"] for el in node_properties},
"rel_props": {el["type"]: el["properties"] for el in edge_properties},
"relationships": triple_schema,
"metadata": {},
}
@property
def get_schema(self) -> str:
    """Returns the schema of the Graph

    This is the text stored in ``self.schema``, which ``refresh_schema``
    assigns.
    """
    return self.schema
@property
def get_structured_schema(self) -> Dict[str, Any]:
    """Returns the structured schema of the Graph

    This is the dictionary stored in ``self.structured_schema``, which
    ``refresh_schema`` assigns.
    """
    return self.structured_schema
@staticmethod
def _get_col_name(field: str, idx: int) -> str:
"""
Convert a cypher return field to a pgsql select field
If possible keep the cypher column name, but create a generic name if necessary
Args:
field (str): a return field from a cypher query to be formatted for pgsql
idx (int): the position of the field in the return statement
Returns:
str: the field to be used in the pgsql select statement
"""
# remove white space
field = field.strip()
# if an alias is provided for the field, use it
if " as " in field:
return field.split(" as ")[-1].strip()
# if the return value is an unnamed primitive, give it a generic name
elif field.isnumeric() or field in ("true", "false", "null"):
return f"column_{idx}"
# otherwise return the value stripping out some common special chars
else:
return field.replace("(", "_").replace(")", "")
@staticmethod
def _wrap_query(query: str, graph_name: str) -> str:
"""
Convert a Cyper query to an Apache Age compatible Sql Query.
Handles combined queries with UNION/EXCEPT operators
Args:
query (str) : A valid cypher query, can include UNION/EXCEPT operators
graph_name (str) : The name of the graph to query
Returns :
str : An equivalent pgSql query wrapped with ag_catalog.cypher
Raises:
ValueError : If query is empty, contain RETURN *, or has invalid field names
"""
if not query.strip():
raise ValueError("Empty query provided")
# pgsql template
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
{query}
$$) AS ({fields});"""
# split the query into parts based on UNION and EXCEPT
parts = re.split(r"\b(UNION\b|\bEXCEPT)\b", query, flags=re.IGNORECASE)
all_fields = []
for part in parts:
if part.strip().upper() in ("UNION", "EXCEPT"):
continue
# if there are any returned fields they must be added to the pgsql query
return_match = re.search(r'\breturn\b(?![^"]*")', part, re.IGNORECASE)
if return_match:
# Extract the part of the query after the RETURN keyword
return_clause = part[return_match.end() :]
# parse return statement to identify returned fields
fields = (
return_clause.lower()
.split("distinct")[-1]
.split("order by")[0]
.split("skip")[0]
.split("limit")[0]
.split(",")
)
# raise exception if RETURN * is found as we can't resolve the fields
clean_fileds = [f.strip() for f in fields if f.strip()]
if "*" in clean_fileds:
raise ValueError(
"Apache Age does not support RETURN * in Cypher queries"
)
# Format fields and maintain order of appearance
for idx, field in enumerate(clean_fileds):
field_name = AGEGraph._get_col_name(field, idx)
if field_name not in all_fields:
all_fields.append(field_name)
# if no return statements found in any part
if not all_fields:
fields_str = "a agtype"
else:
fields_str = ", ".join(f"{field} agtype" for field in all_fields)
return template.format(
graph_name=graph_name,
query=query,
fields=fields_str,
projection="*",
)
@staticmethod
def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
"""
Convert a record returned from an age query to a dictionary
Args:
record (): a record from an age query result
Returns:
Dict[str, Any]: a dictionary representation of the record where
the dictionary key is the field name and the value is the
value converted to a python type
"""
# result holder
d = {}
# prebuild a mapping of vertex_id to vertex mappings to be used
# later to build edges
vertices = {}
for k in record._fields:
v = getattr(record, k)
# agtype comes back '{key: value}::type' which must be parsed
if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
vertices[vertex["id"]] = vertex.get("properties")
# iterate returned fields and parse appropriately
for k in record._fields:
v = getattr(record, k)
if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1]
v = v.split("::")[0]
else:
dtype = ""
if dtype == "vertex":
d[k] = json.loads(v).get("properties")
# convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query
# this is an attempt to be consistent with neo4j implementation
elif dtype == "edge":
edge = json.loads(v)
d[k] = (
vertices.get(edge["start_id"], {}),
edge["label"],
vertices.get(edge["end_id"], {}),
)
else:
d[k] = json.loads(v) if isinstance(v, str) else v
return d
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
    """
    Query the graph by taking a cypher query, converting it to an
    age compatible query, executing it and converting the result

    Args:
        query (str): a cypher query to be executed
        params (dict): parameters for the query (not used in this
            implementation; the default is never read or mutated)

    Returns:
        List[Dict[str, Any]]: a list of dictionaries containing the result set

    Raises:
        ImportError: if psycopg2 cannot be imported
        ValueError: if the query cannot be wrapped (see ``_wrap_query``)
        AGEQueryException: if the database rejects the query; the
            transaction is rolled back first
    """
    try:
        import psycopg2
    except ImportError as e:
        raise ImportError(
            "Unable to import psycopg2, please install with "
            "`pip install -U psycopg2`."
        ) from e

    # convert cypher query to pgsql/age query
    wrapped_query = self._wrap_query(query, self.graph_name)

    # execute the query, rolling back on an error
    with self._get_cursor() as curs:
        try:
            curs.execute(wrapped_query)
            self.connection.commit()
        except psycopg2.Error as e:
            self.connection.rollback()
            raise AGEQueryException(
                {
                    "message": "Error executing graph query: {}".format(query),
                    # "details" is the key AGEQueryException reads
                    "details": str(e),
                }
            ) from e

        # convert the fetched rows to dictionaries
        rows = curs.fetchall() or []
        return [self._record_to_dict(row) for row in rows]
@staticmethod
def _format_properties(
properties: Dict[str, Any], id: Union[str, None] = None
) -> str:
"""
Convert a dictionary of properties to a string representation that
can be used in a cypher query insert/merge statement.
Args:
properties (Dict[str,str]): a dictionary containing node/edge properties
id (Union[str, None]): the id of the node or None if none exists
Returns:
str: the properties dictionary as a properly formatted string
"""
props = []
# wrap property key in backticks to escape
for k, v in properties.items():
prop = f"`{k}`: {json.dumps(v)}"
props.append(prop)
if id is not None and "id" not in properties:
props.append(
f"id: {json.dumps(id)}" if isinstance(id, str) else f"id: {id}"
)
return "{" + ", ".join(props) + "}"
@staticmethod
def clean_graph_labels(label: str) -> str:
"""
remove any disallowed characters from a label and replace with '_'
Args:
label (str): the original label
Returns:
str: the sanitized version of the label
"""
return re.sub(AGEGraph.label_regex, "_", label)
def add_graph_documents(
self, graph_documents: List[GraphDocument], include_source: bool = False
) -> None:
"""
insert a list of graph documents into the graph
Args:
graph_documents (List[GraphDocument]): the list of documents to be inserted
include_source (bool): if True add nodes for the sources
with MENTIONS edges to the entities they mention
Returns:
None
"""
# query for inserting nodes
node_insert_query = (
"""
MERGE (n:`{label}` {{`id`: "{id}"}})
SET n = {properties}
"""
if not include_source
else """
MERGE (n:`{label}` {properties})
MERGE (d:Document {d_properties})
MERGE (d)-[:MENTIONS]->(n)
"""
)
# query for inserting edges
edge_insert_query = """
MERGE (from:`{f_label}` {f_properties})
MERGE (to:`{t_label}` {t_properties})
MERGE (from)-[:`{r_label}` {r_properties}]->(to)
"""
# iterate docs and insert them
for doc in graph_documents:
# if we are adding sources, create an id for the source
if include_source:
if not doc.source.metadata.get("id"):
doc.source.metadata["id"] = md5(
doc.source.page_content.encode("utf-8")
).hexdigest()
# insert entity nodes
for node in doc.nodes:
node.properties["id"] = node.id
if include_source:
query = node_insert_query.format(
label=node.type,
properties=self._format_properties(node.properties),
d_properties=self._format_properties(doc.source.metadata),
)
else:
query = node_insert_query.format(
label=AGEGraph.clean_graph_labels(node.type),
properties=self._format_properties(node.properties),
id=node.id,
)
self.query(query)
# insert relationships
for edge in doc.relationships:
edge.source.properties["id"] = edge.source.id
edge.target.properties["id"] = edge.target.id
inputs = {
"f_label": AGEGraph.clean_graph_labels(edge.source.type),
"f_properties": self._format_properties(edge.source.properties),
"t_label": AGEGraph.clean_graph_labels(edge.target.type),
"t_properties": self._format_properties(edge.target.properties),
"r_label": AGEGraph.clean_graph_labels(edge.type).upper(),
"r_properties": self._format_properties(edge.properties),
}
query = edge_insert_query.format(**inputs)
self.query(query)

View File

@@ -0,0 +1,182 @@
import os
from math import ceil
from typing import Any, Dict, List, Optional
class ArangoGraph:
    """ArangoDB wrapper for graph operations.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, db: Any) -> None:
        """Create a new ArangoDB graph wrapper instance.

        Args:
            db: an ``arango.database.Database`` instance.
        """
        # set_db() validates the handle and already generates the schema,
        # so no separate set_schema() call is needed here.
        self.set_db(db)

    @property
    def db(self) -> Any:
        """The wrapped database handle."""
        return self.__db

    @property
    def schema(self) -> Dict[str, Any]:
        """The schema last stored by ``set_schema``."""
        return self.__schema

    def set_db(self, db: Any) -> None:
        """Replace the wrapped database and regenerate the schema.

        Raises:
            TypeError: if ``db`` is not an ``arango.database.Database``.
        """
        from arango.database import Database

        if not isinstance(db, Database):
            msg = "**db** parameter must inherit from arango.database.Database"
            raise TypeError(msg)

        self.__db: Database = db
        self.set_schema()

    def set_schema(self, schema: Optional[Dict[str, Any]] = None) -> None:
        """
        Set the schema of the ArangoDB Database.
        Auto-generates Schema if **schema** is None.
        """
        self.__schema = self.generate_schema() if schema is None else schema

    def generate_schema(
        self, sample_ratio: float = 0
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Generates the schema of the ArangoDB Database and returns it
        User can specify a **sample_ratio** (0 to 1) to determine the
        ratio of documents/edges used (in relation to the Collection size)
        to render each Collection Schema.

        At least one document is sampled per non-empty collection. Each
        distinct (name, type) property pair is listed once, in order of
        first appearance; the example is the last sampled document.

        Raises:
            ValueError: if **sample_ratio** is outside the range 0 to 1.
        """
        if not 0 <= sample_ratio <= 1:
            raise ValueError("**sample_ratio** value must be in between 0 to 1")

        # Stores the Edge Relationships between each ArangoDB Document Collection
        graph_schema: List[Dict[str, Any]] = [
            {"graph_name": g["name"], "edge_definitions": g["edge_definitions"]}
            for g in self.db.graphs()
        ]

        # Stores the schema of every ArangoDB Document/Edge collection
        collection_schema: List[Dict[str, Any]] = []

        for collection in self.db.collections():
            if collection["system"]:
                continue

            # Extract collection name, type, and size
            col_name: str = collection["name"]
            col_type: str = collection["type"]
            col_size: int = self.db.collection(col_name).count()

            # Skip collection if empty
            if col_size == 0:
                continue

            # Set number of ArangoDB documents/edges to retrieve
            limit_amount = ceil(sample_ratio * col_size) or 1

            aql = f"""
                FOR doc in `{col_name}`
                    LIMIT {limit_amount}
                    RETURN doc
            """

            properties: List[Dict[str, str]] = []
            seen = set()
            example: Optional[Dict[str, Any]] = None
            for doc in self.__db.aql.execute(aql):
                example = doc
                for key, value in doc.items():
                    entry = (key, type(value).__name__)
                    # several sampled documents usually share their keys;
                    # record each (name, type) pair only once
                    if entry not in seen:
                        seen.add(entry)
                        properties.append({"name": key, "type": entry[1]})

            # The collection may have been emptied between count() and
            # the sample query; there is nothing to describe then.
            if example is None:
                continue

            collection_schema.append(
                {
                    "collection_name": col_name,
                    "collection_type": col_type,
                    f"{col_type}_properties": properties,
                    f"example_{col_type}": example,
                }
            )

        return {"Graph Schema": graph_schema, "Collection Schema": collection_schema}

    def query(
        self, query: str, top_k: Optional[int] = None, **kwargs: Any
    ) -> List[Dict[str, Any]]:
        """Query the ArangoDB database.

        Args:
            query: the AQL query to execute.
            top_k: maximum number of results to return; None returns all.
            **kwargs: forwarded to ``aql.execute``.
        """
        import itertools

        cursor = self.__db.aql.execute(query, **kwargs)
        return list(itertools.islice(cursor, top_k))

    @classmethod
    def from_db_credentials(
        cls,
        url: Optional[str] = None,
        dbname: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ) -> Any:
        """Convenience constructor that builds Arango DB from credentials.

        Args:
            url: Arango DB url. Can be passed in as named arg or set as environment
                var ``ARANGODB_URL``. Defaults to "http://localhost:8529".
            dbname: Arango DB name. Can be passed in as named arg or set as
                environment var ``ARANGODB_DBNAME``. Defaults to "_system".
            username: Can be passed in as named arg or set as environment var
                ``ARANGODB_USERNAME``. Defaults to "root".
            password: Can be passed in as named arg or set as environment var
                ``ARANGODB_PASSWORD``. Defaults to "".

        Returns:
            An instance of this class wrapping the database returned by
            ``get_arangodb_client``.
        """
        db = get_arangodb_client(
            url=url, dbname=dbname, username=username, password=password
        )
        return cls(db)
def get_arangodb_client(
    url: Optional[str] = None,
    dbname: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> Any:
    """Get the Arango DB client from credentials.

    Every argument that is empty or None falls back to its environment
    variable, and then to the default listed below.

    Args:
        url: Arango DB url. Can be passed in as named arg or set as environment
            var ``ARANGODB_URL``. Defaults to "http://localhost:8529".
        dbname: Arango DB name. Can be passed in as named arg or set as
            environment var ``ARANGODB_DBNAME``. Defaults to "_system".
        username: Can be passed in as named arg or set as environment var
            ``ARANGODB_USERNAME``. Defaults to "root".
        password: Can be passed in as named arg or set as environment var
            ``ARANGODB_PASSWORD``. Defaults to "".

    Returns:
        An arango.database.StandardDatabase.

    Raises:
        ImportError: if the ``arango`` package cannot be imported.
    """
    try:
        from arango import ArangoClient
    except ImportError as e:
        raise ImportError(
            "Unable to import arango, please install with `pip install python-arango`."
        ) from e

    env = os.environ
    resolved_url: str = url if url else env.get("ARANGODB_URL", "http://localhost:8529")
    resolved_dbname: str = dbname if dbname else env.get("ARANGODB_DBNAME", "_system")
    resolved_username: str = (
        username if username else env.get("ARANGODB_USERNAME", "root")
    )
    resolved_password: str = (
        password if password else env.get("ARANGODB_PASSWORD", "")
    )

    client = ArangoClient(resolved_url)
    return client.db(
        resolved_dbname, resolved_username, resolved_password, verify=True
    )

View File

@@ -0,0 +1,201 @@
import warnings
from typing import Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore
# Cypher: for every node label, collect the distinct property keys found on
# nodes carrying that label. Nodes without properties contribute a NULL key.
node_properties_query = """
MATCH (n)
WITH keys(n) as keys, labels(n) AS labels
WITH CASE WHEN keys = [] THEN [NULL] ELSE keys END AS keys, labels
UNWIND labels AS label
UNWIND keys AS key
WITH label, collect(DISTINCT key) AS keys
RETURN {label:label, keys:keys} AS output
"""
# Cypher: for every relationship type, collect the distinct property keys
# found on relationships of that type (returned under the key "types").
rel_properties_query = """
MATCH ()-[r]->()
WITH keys(r) as keys, type(r) AS types
WITH CASE WHEN keys = [] THEN [NULL] ELSE keys END AS keys, types
UNWIND types AS type
UNWIND keys AS key WITH type,
collect(DISTINCT key) AS keys
RETURN {types:type, keys:keys} AS output
"""
# Cypher: list the distinct (start label, relationship type, end label)
# combinations present in the graph.
rel_query = """
MATCH (n)-[r]->(m)
UNWIND labels(n) as src_label
UNWIND labels(m) as dst_label
UNWIND type(r) as rel_type
RETURN DISTINCT {start: src_label, type: rel_type, end: dst_label} AS output
"""
class FalkorDBGraph(GraphStore):
"""FalkorDB wrapper for graph operations.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
def __init__(
self,
database: str,
host: str = "localhost",
port: int = 6379,
username: Optional[str] = None,
password: Optional[str] = None,
ssl: bool = False,
) -> None:
"""Create a new FalkorDB graph wrapper instance."""
try:
self.__init_falkordb_connection(
database, host, port, username, password, ssl
)
except ImportError:
try:
# Falls back to using the redis package just for backwards compatibility
self.__init_redis_connection(
database, host, port, username, password, ssl
)
except ImportError:
raise ImportError(
"Could not import falkordb python package. "
"Please install it with `pip install falkordb`."
)
self.schema: str = ""
self.structured_schema: Dict[str, Any] = {}
try:
self.refresh_schema()
except Exception as e:
raise ValueError(f"Could not refresh schema. Error: {e}")
def __init_falkordb_connection(
self,
database: str,
host: str = "localhost",
port: int = 6379,
username: Optional[str] = None,
password: Optional[str] = None,
ssl: bool = False,
) -> None:
from falkordb import FalkorDB
try:
self._driver = FalkorDB(
host=host, port=port, username=username, password=password, ssl=ssl
)
except Exception as e:
raise ConnectionError(f"Failed to connect to FalkorDB: {e}")
self._graph = self._driver.select_graph(database)
@deprecated("0.0.31", alternative="__init_falkordb_connection")
def __init_redis_connection(
self,
database: str,
host: str = "localhost",
port: int = 6379,
username: Optional[str] = None,
password: Optional[str] = None,
ssl: bool = False,
) -> None:
import redis
from redis.commands.graph import Graph
# show deprecation warning
warnings.warn(
"Using the redis package is deprecated. "
"Please use the falkordb package instead, "
"install it with `pip install falkordb`.",
DeprecationWarning,
)
self._driver = redis.Redis(
host=host, port=port, username=username, password=password, ssl=ssl
)
self._graph = Graph(self._driver, database)
@property
def get_schema(self) -> str:
"""Returns the schema of the FalkorDB database"""
return self.schema
@property
def get_structured_schema(self) -> Dict[str, Any]:
"""Returns the structured schema of the Graph"""
return self.structured_schema
def refresh_schema(self) -> None:
"""Refreshes the schema of the FalkorDB database"""
node_properties: List[Any] = self.query(node_properties_query)
rel_properties: List[Any] = self.query(rel_properties_query)
relationships: List[Any] = self.query(rel_query)
self.structured_schema = {
"node_props": {el[0]["label"]: el[0]["keys"] for el in node_properties},
"rel_props": {el[0]["types"]: el[0]["keys"] for el in rel_properties},
"relationships": [el[0] for el in relationships],
}
self.schema = (
f"Node properties: {node_properties}\n"
f"Relationships properties: {rel_properties}\n"
f"Relationships: {relationships}\n"
)
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
    """Query FalkorDB database.

    Args:
        query: Cypher statement to execute.
        params: Parameters passed to the graph client together with the
            statement. The default mapping is never mutated here.

    Returns:
        The ``result_set`` attribute of the client's query result.

    Raises:
        ValueError: If the client raises while executing the statement. The
            original exception is attached as ``__cause__``.
    """
    try:
        data = self._graph.query(query, params)
        return data.result_set
    except Exception as e:
        # Keep the driver error as the explicit cause for easier debugging.
        raise ValueError(f"Generated Cypher Statement is not valid\n{e}") from e
def add_graph_documents(
    self, graph_documents: List[GraphDocument], include_source: bool = False
) -> None:
    """Take GraphDocument as input and use it to construct a graph.

    Every node is merged on ``(label, id)`` and its properties are added with
    ``SET n += $properties``. Every relationship is merged between the two
    matched endpoint nodes; the relationship type is upper-cased with spaces
    replaced by underscores.

    Node ids are sent as query parameters (as strings, matching the quoted
    literals used previously) instead of being spliced into the statement, so
    ids containing quotes can no longer break or alter the query. Labels and
    relationship types cannot be parameterised and are still interpolated.

    Args:
        graph_documents: Documents whose nodes and relationships are imported.
        include_source: Accepted for interface compatibility; not used here.
    """
    for document in graph_documents:
        # Import nodes
        for node in document.nodes:
            self.query(
                (
                    f"MERGE (n:{node.type} {{id: $id}}) "
                    "SET n += $properties "
                    "RETURN distinct 'done' AS result"
                ),
                {"id": str(node.id), "properties": node.properties},
            )

        # Import relationships
        for rel in document.relationships:
            rel_type = rel.type.replace(" ", "_").upper()
            self.query(
                (
                    f"MATCH (a:{rel.source.type} {{id: $source_id}}), "
                    f"(b:{rel.target.type} {{id: $target_id}}) "
                    f"MERGE (a)-[r:{rel_type}]->(b) "
                    "SET r += $properties "
                    "RETURN distinct 'done' AS result"
                ),
                {
                    "source_id": str(rel.source.id),
                    "target_id": str(rel.target.id),
                    "properties": rel.properties,
                },
            )

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
from typing import List, Union
from langchain_core.documents import Document
from langchain_core.load.serializable import Serializable
from pydantic import Field
class Node(Serializable):
    """Represents a node in a graph with associated properties.

    Attributes:
        id (Union[str, int]): A unique identifier for the node.
        type (str): The type or label of the node, default is "Node".
        properties (dict): Additional properties and metadata associated with the node.
    """

    # Required identifier; either a string or an integer is accepted.
    id: Union[str, int]
    # Type/label of the node; "Node" is used when none is given.
    type: str = "Node"
    # default_factory=dict gives every instance its own empty dict.
    properties: dict = Field(default_factory=dict)
class Relationship(Serializable):
    """Represents a directed relationship between two nodes in a graph.

    Attributes:
        source (Node): The source node of the relationship.
        target (Node): The target node of the relationship.
        type (str): The type of the relationship.
        properties (dict): Additional properties associated with the relationship.
    """

    # Node the relationship starts from.
    source: Node
    # Node the relationship points to.
    target: Node
    # Relationship type; required, there is no default.
    type: str
    # default_factory=dict gives every instance its own empty dict.
    properties: dict = Field(default_factory=dict)
class GraphDocument(Serializable):
    """Represents a graph document consisting of nodes and relationships.

    Attributes:
        nodes (List[Node]): A list of nodes in the graph.
        relationships (List[Relationship]): A list of relationships in the graph.
        source (Document): The document from which the graph information is derived.
    """

    # All three fields are required; none of them declares a default.
    nodes: List[Node]
    relationships: List[Relationship]
    source: Document

View File

@@ -0,0 +1,37 @@
from abc import abstractmethod
from typing import Any, Dict, List
from langchain_community.graphs.graph_document import GraphDocument
class GraphStore:
    """Abstract class for graph operations."""

    @property
    @abstractmethod
    def get_schema(self) -> str:
        """Return the schema of the graph database as text."""

    @property
    @abstractmethod
    def get_structured_schema(self) -> Dict[str, Any]:
        """Return the schema of the graph database as a dictionary."""

    @abstractmethod
    def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
        """Run ``query`` with ``params`` against the graph and return the rows."""

    @abstractmethod
    def refresh_schema(self) -> None:
        """Reload the schema information held by the store."""

    @abstractmethod
    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """Build graph content from the given ``GraphDocument`` objects."""

View File

@@ -0,0 +1,228 @@
import hashlib
import sys
from typing import Any, Dict, List, Optional, Union
from langchain_core.utils import get_from_env
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_community.graphs.graph_store import GraphStore
class GremlinGraph(GraphStore):
"""Gremlin wrapper for graph operations.
Parameters:
url (Optional[str]): The URL of the Gremlin database server or env GREMLIN_URI
username (Optional[str]): The collection-identifier like '/dbs/database/colls/graph'
or env GREMLIN_USERNAME if none provided
password (Optional[str]): The connection-key for database authentication
or env GREMLIN_PASSWORD if none provided
traversal_source (str): The traversal source to use for queries. Defaults to 'g'.
message_serializer (Optional[Any]): The message serializer to use for requests.
Defaults to serializer.GraphSONSerializersV2d0()
include_edge_properties (bool): Whether to include edge properties in
the gremlin graph schema. Defaults to False
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
*Implementation details*:
The Gremlin queries are designed to work with Azure CosmosDB limitations
"""
@property
def get_structured_schema(self) -> Dict[str, Any]:
    """Return the structured schema, loading it on first access.

    ``structured_schema`` is populated by ``refresh_schema``. When it is
    missing or empty the schema is refreshed first, mirroring the lazy
    behaviour of ``get_schema``, instead of failing with ``AttributeError``.
    """
    if not getattr(self, "structured_schema", None):
        self.refresh_schema()
    return self.structured_schema
def __init__(
    self,
    url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    traversal_source: str = "g",
    message_serializer: Optional[Any] = None,
    include_edge_properties: bool = False,
) -> None:
    """Create a new Gremlin graph wrapper instance.

    Args:
        url: Server URL; resolved via ``get_from_env`` with ``GREMLIN_URI``.
        username: User name; resolved with ``GREMLIN_USERNAME``.
        password: Password; resolved with ``GREMLIN_PASSWORD``.
        traversal_source: Traversal source handed to the client.
        message_serializer: Serializer for requests; when falsy a
            ``serializer.GraphSONSerializersV2d0()`` instance is used.
        include_edge_properties: Whether ``refresh_schema`` also collects
            edge properties.

    Raises:
        ImportError: If ``gremlin_python`` cannot be imported.
    """
    try:
        import asyncio

        from gremlin_python.driver import client, serializer
    except ImportError as e:
        raise ImportError(
            "Please install gremlin-python first: `pip3 install gremlinpython`"
        ) from e

    if sys.platform == "win32":
        # NOTE(review): presumably needed by the driver's event loop on
        # Windows; confirm before changing.
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

    self.client = client.Client(
        url=get_from_env("url", "GREMLIN_URI", url),
        traversal_source=traversal_source,
        username=get_from_env("username", "GREMLIN_USERNAME", username),
        password=get_from_env("password", "GREMLIN_PASSWORD", password),
        message_serializer=message_serializer
        if message_serializer
        else serializer.GraphSONSerializersV2d0(),
    )
    self.schema: str = ""
    # Start with an empty structured schema so the attribute always exists,
    # even before the first refresh.
    self.structured_schema: Dict[str, Any] = {}
    self.include_edge_properties = include_edge_properties
@property
def get_schema(self) -> str:
    """Return the Gremlin schema text, refreshing it first when it is empty."""
    if len(self.schema) > 0:
        return self.schema
    self.refresh_schema()
    return self.schema
def refresh_schema(self) -> None:
    """Reload vertex/edge labels and properties from the Gremlin server.

    Stores the raw query results in ``self.structured_schema`` and a text
    rendering in ``self.schema``. Edge properties are only collected when
    ``self.include_edge_properties`` is true.
    """
    submit = self.client.submit

    vertex_labels = submit("g.V().label().dedup()").all().result()
    edge_labels = submit("g.E().label().dedup()").all().result()
    vertex_props_query = (
        "g.V().group().by(label).by(properties().label().dedup().fold())"
    )
    # Only the first element of the result list is kept.
    vertex_props = submit(vertex_props_query).all().result()[0]

    self.structured_schema = {
        "vertex_labels": vertex_labels,
        "edge_labels": edge_labels,
        "vertice_props": vertex_props,
    }

    schema_lines = [
        "Vertex labels are the following:",
        ",".join(vertex_labels),
        "Edge labels are the following:",
        ",".join(edge_labels),
        f"Vertices have following properties:\n{vertex_props}",
    ]
    self.schema = "\n".join(schema_lines)

    if not self.include_edge_properties:
        return

    edge_props_query = (
        "g.E().group().by(label)"
        ".by(project('inVLabel', 'outVLabel','properties')"
        ".by(inV().label()).by(outV().label()).by(properties().key().dedup()"
        ".fold()).dedup().fold())"
    )
    edge_props = submit(edge_props_query).all().result()[0]
    self.structured_schema["edge_props"] = edge_props
    self.schema += (
        f"\nEdges have the following properties, grouped by label and"
        f" the distinct inV and outV labels:\n {edge_props}"
    )
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
    """Submit a Gremlin query and return all of its results.

    Args:
        query: Gremlin script to submit.
        params: Optional bindings for the script. They are forwarded to the
            client as ``bindings``; previously they were accepted but dropped.
            The default mapping is never mutated.

    Returns:
        The value of ``.all().result()`` for the submitted query.
    """
    if params:
        result_set = self.client.submit(query, bindings=params)
    else:
        # No bindings supplied: call the client exactly as before.
        result_set = self.client.submit(query)
    return result_set.all().result()
def add_graph_documents(
    self, graph_documents: List[GraphDocument], include_source: bool = False
) -> None:
    """Write the nodes and relationships of ``graph_documents`` to the graph.

    With ``include_source`` a "Document" vertex (id = MD5 of the page
    content) is created per document and linked to every node of that
    document in both directions.
    """
    node_cache: Dict[Union[str, int], Node] = {}

    for graph_doc in graph_documents:
        if include_source:
            # Create document vertex
            source = graph_doc.source
            doc_props = {
                "page_content": source.page_content,
                "metadata": source.metadata,
            }
            doc_id = hashlib.md5(source.page_content.encode()).hexdigest()
            doc_node = self.add_node(
                Node(id=doc_id, type="Document", properties=doc_props), node_cache
            )

        # Import nodes to vertices
        for raw_node in graph_doc.nodes:
            vertex = self.add_node(raw_node)
            if not include_source:
                continue
            # Add an edge in each direction between the document and the node
            for edge_type, start, end in (
                ("contains information about", doc_node, vertex),
                ("is extracted from", vertex, doc_node),
            ):
                self.add_edge(
                    Relationship(
                        type=edge_type,
                        source=start,
                        target=end,
                        properties={},
                    )
                )

        # Edges
        for relationship in graph_doc.relationships:
            # Find or create both endpoint vertices, then the edge itself
            self.add_node(relationship.source, node_cache)
            self.add_node(relationship.target, node_cache)
            self.add_edge(relationship)
def build_vertex_query(self, node: Node) -> str:
    """Build the Gremlin script that fetches or creates the vertex for ``node``.

    The script looks the vertex up by ``id`` and, via ``coalesce``, adds it
    with its ``id``, ``type`` and all entries of ``node.properties`` when it
    does not exist.
    """
    fragments = [
        f"g.V().has('id','{node.id}').fold()",
        f".coalesce(unfold(),addV('{node.type}')",
        f".property('id','{node.id}')",
        f".property('type','{node.type}')",
    ]
    fragments.extend(
        f".property('{name}', '{val}')" for name, val in node.properties.items()
    )
    # Close the coalesce(...) call opened above.
    fragments.append(")")
    return "".join(fragments)
def build_edge_query(self, relationship: Relationship) -> str:
    """Build the Gremlin script that finds or creates an edge.

    The script selects the source vertex as ``a`` and the target vertex as
    ``b`` (both by ``id``) and uses ``choose`` to add an edge of
    ``relationship.type`` from ``a`` to ``b`` only when ``b`` has no incoming
    edge of that type from ``a``. One ``.property(...)`` step is appended per
    entry of ``relationship.properties``.

    The previous template opened with four double quotes, which left a stray
    ``"`` at the start of every generated query; the query is now assembled
    from fragments and starts directly with ``g.V()``.
    """
    edge_type = relationship.type
    fragments = [
        f"g.V().has('id','{relationship.source.id}').as('a')",
        f".V().has('id','{relationship.target.id}').as('b')",
        ".choose(",
        f"__.inE('{edge_type}').where(outV().as('a')),",
        "__.identity(),",
        f"__.addE('{edge_type}').from('a').to('b')",
        ")",
    ]
    fragments.extend(
        f".property('{key}', '{value}')"
        for key, value in relationship.properties.items()
    )
    return "".join(fragments)
def add_node(self, node: Node, node_cache: Optional[dict] = None) -> Node:
    """Create the vertex for ``node`` unless it is already in ``node_cache``.

    Args:
        node: Node to add. Its ``properties`` gain a ``"label"`` entry equal
            to ``node.type`` when none is present (the node is mutated).
        node_cache: Optional mapping of node id to already-added nodes. When
            omitted, a fresh cache is used for this call. The former ``{}``
            default was a single dict shared by every call and instance,
            which caused later inserts of a known id to be skipped.

    Returns:
        The cached node for ``node.id`` if present, otherwise ``node``.
    """
    if node_cache is None:
        node_cache = {}

    # if properties does not have label, add type as label
    if "label" not in node.properties:
        node.properties["label"] = node.type

    if node.id in node_cache:
        return node_cache[node.id]

    query = self.build_vertex_query(node)
    # Indexing the result keeps the original check that the server returned
    # at least one element.
    _ = self.client.submit(query).all().result()[0]
    node_cache[node.id] = node
    return node
def add_edge(self, relationship: Relationship) -> Any:
    """Submit the find-or-create edge script for ``relationship``.

    Returns whatever ``.all().result()`` yields for the submitted script.
    """
    edge_script = self.build_edge_query(relationship)
    pending = self.client.submit(edge_script)
    return pending.all().result()

View File

@@ -0,0 +1,74 @@
from typing import Any, Dict, List
class HugeGraph:
    """HugeGraph wrapper for graph operations.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        username: str = "default",
        password: str = "default",
        address: str = "127.0.0.1",
        port: int = 8081,
        graph: str = "hugegraph",
    ) -> None:
        """Create a new HugeGraph wrapper instance.

        Args:
            username: User name passed to the client as ``user``.
            password: Password passed to the client as ``pwd``.
            address: Server address.
            port: Server port.
            graph: Name of the graph to use.

        Raises:
            ImportError: If the ``hugegraph`` client package is missing.
            ValueError: If the initial schema refresh fails. The original
                exception is attached as ``__cause__``.
        """
        try:
            from hugegraph.connection import PyHugeGraph
        except ImportError as e:
            raise ImportError(
                "Please install HugeGraph Python client first: "
                "`pip3 install hugegraph-python`"
            ) from e

        self.username = username
        self.password = password
        self.address = address
        self.port = port
        self.graph = graph
        self.client = PyHugeGraph(
            address, port, user=username, pwd=password, graph=graph
        )
        self.schema = ""
        # Set schema
        try:
            self.refresh_schema()
        except Exception as e:
            # Chain the original error so the root cause is not lost.
            raise ValueError(f"Could not refresh schema. Error: {e}") from e

    @property
    def get_schema(self) -> str:
        """Returns the schema of the HugeGraph database"""
        return self.schema

    def refresh_schema(self) -> None:
        """
        Refreshes the HugeGraph schema information.

        Reads vertex labels, edge labels and relations from the client's
        schema object and stores a text rendering in ``self.schema``.
        """
        schema = self.client.schema()
        vertex_schema = schema.getVertexLabels()
        edge_schema = schema.getEdgeLabels()
        relationships = schema.getRelations()

        self.schema = (
            f"Node properties: {vertex_schema}\n"
            f"Edge properties: {edge_schema}\n"
            f"Relationships: {relationships}\n"
        )

    def query(self, query: str) -> List[Dict[str, Any]]:
        """Run a Gremlin query and return the ``"data"`` entry of the response."""
        g = self.client.gremlin()
        res = g.exec(query)
        return res["data"]

View File

@@ -0,0 +1,99 @@
from typing import Optional, Type
from pydantic import BaseModel
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_community.graphs import NetworkxEntityGraph
from langchain_community.graphs.networkx_graph import KG_TRIPLE_DELIMITER
from langchain_community.graphs.networkx_graph import parse_triples
# flake8: noqa
# Prompt text used to ask a language model for knowledge triples.
# The f-string segments embed KG_TRIPLE_DELIMITER when this module is
# imported; "{text}" sits in a plain string segment and is therefore left as
# the single template variable filled in by PromptTemplate.
# NOTE(review): "{text}" is followed directly by "Output:" with no separator
# in between; confirm this is intended before changing the prompt.
_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = (
    "You are a networked intelligence helping a human track knowledge triples"
    " about all relevant people, things, concepts, etc. and integrating"
    " them with your knowledge stored within your weights"
    " as well as that stored in a knowledge graph."
    " Extract all of the knowledge triples from the text."
    " A knowledge triple is a clause that contains a subject, a predicate,"
    " and an object. The subject is the entity being described,"
    " the predicate is the property of the subject that is being"
    " described, and the object is the value of the property.\n\n"
    "EXAMPLE\n"
    "It's a state in the US. It's also the number 1 producer of gold in the US.\n\n"
    f"Output: (Nevada, is a, state){KG_TRIPLE_DELIMITER}(Nevada, is in, US)"
    f"{KG_TRIPLE_DELIMITER}(Nevada, is the number 1 producer of, gold)\n"
    "END OF EXAMPLE\n\n"
    "EXAMPLE\n"
    "I'm going to the store.\n\n"
    "Output: NONE\n"
    "END OF EXAMPLE\n\n"
    "EXAMPLE\n"
    "Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n"
    f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n"
    "END OF EXAMPLE\n\n"
    "EXAMPLE\n"
    "{text}"
    "Output:"
)
# Default prompt of GraphIndexCreator.from_text / afrom_text; its only input
# variable is "text".
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT = PromptTemplate(
    input_variables=["text"],
    template=_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE,
)
class GraphIndexCreator(BaseModel):
    """Functionality to create graph index.

    Attributes:
        llm: Language model used to extract knowledge triples; must be set
            before ``from_text`` / ``afrom_text`` is called.
        graph_type: Graph class instantiated (without arguments) for every
            call.
    """

    llm: Optional[BaseLanguageModel] = None
    graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph

    def _prepare(self, prompt: BasePromptTemplate):
        """Validate ``llm`` and return a new ``(graph, chain)`` pair."""
        if self.llm is None:
            raise ValueError("llm should not be None")
        graph = self.graph_type()
        # Temporary local scoped import while community does not depend on
        # langchain explicitly
        try:
            from langchain_classic.chains import LLMChain
        except ImportError as e:
            raise ImportError(
                "Please install langchain to use this functionality. "
                "You can install it with `pip install langchain`."
            ) from e
        return graph, LLMChain(llm=self.llm, prompt=prompt)

    @staticmethod
    def _load_triples(graph: NetworkxEntityGraph, output: str) -> NetworkxEntityGraph:
        """Parse triples out of ``output`` and add each of them to ``graph``."""
        for triple in parse_triples(output):
            graph.add_triple(triple)
        return graph

    def from_text(
        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
    ) -> NetworkxEntityGraph:
        """Create graph index from text.

        Args:
            text: Text passed to the chain as the ``text`` variable.
            prompt: Prompt used to build the chain.

        Returns:
            A new graph containing every triple parsed from the model output.

        Raises:
            ValueError: If ``llm`` is ``None``.
            ImportError: If ``LLMChain`` cannot be imported.
        """
        graph, chain = self._prepare(prompt)
        output = chain.predict(text=text)
        return self._load_triples(graph, output)

    async def afrom_text(
        self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
    ) -> NetworkxEntityGraph:
        """Create graph index from text asynchronously.

        Same contract as ``from_text`` but awaits ``chain.apredict``.
        """
        graph, chain = self._prepare(prompt)
        output = await chain.apredict(text=text)
        return self._load_triples(graph, output)

View File

@@ -0,0 +1,264 @@
from hashlib import md5
from typing import Any, Dict, List, Tuple
from langchain_community.graphs.graph_document import GraphDocument, Relationship
class KuzuGraph:
"""Kùzu wrapper for graph operations.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
def __init__(
self, db: Any, database: str = "kuzu", allow_dangerous_requests: bool = False
) -> None:
"""Initializes the Kùzu graph database connection."""
if allow_dangerous_requests is not True:
raise ValueError(
"The KuzuGraph class is a powerful tool that can be used to execute "
"arbitrary queries on the database. To enable this functionality, "
"set the `allow_dangerous_requests` parameter to `True` when "
"constructing the KuzuGraph object."
)
try:
import kuzu
except ImportError:
raise ImportError(
"Could not import Kùzu python package."
"Please install Kùzu with `pip install kuzu`."
)
self.db = db
self.conn = kuzu.Connection(self.db)
self.database = database
self.refresh_schema()
@property
def get_schema(self) -> str:
    """Return the schema text most recently built by ``refresh_schema``."""
    current_schema = self.schema
    return current_schema
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
    """Execute ``query`` on the Kùzu connection and return rows as dicts.

    Each returned dictionary maps the result's column names to the values of
    one row.
    """
    cursor = self.conn.execute(query, params)
    columns = cursor.get_column_names()
    records = []
    while cursor.has_next():
        records.append(dict(zip(columns, cursor.get_next())))
    return records
def refresh_schema(self) -> None:
    """Rebuild ``self.schema`` from the tables known to the Kùzu connection.

    Collects node table properties (with list suffixes for multi-dimensional
    types), relationship patterns and relationship table properties, and
    stores them as a three-line text description.
    """
    # Node tables: one entry per table with (name, type) property tuples.
    node_properties = []
    for node_table in self.conn._get_node_table_names():
        table_entry = {"properties": [], "label": node_table}
        table_props = self.conn._get_node_property_names(node_table)
        for prop_name in table_props:
            details = table_props[prop_name]
            type_name = details["type"]
            suffix = ""
            if details["dimension"] > 0:
                if "shape" in details:
                    suffix = "".join(f"[{extent}]" for extent in details["shape"])
                else:
                    # No explicit shape: one "[]" per dimension.
                    suffix = "".join("[]" for _ in range(details["dimension"]))
            table_entry["properties"].append((prop_name, type_name + suffix))
        node_properties.append(table_entry)

    rel_tables = self.conn._get_rel_table_names()

    # Relationship patterns of the form (:src)-[:name]->(:dst).
    relationships = [
        f"(:{table['src']})-[:{table['name']}]->(:{table['dst']})"
        for table in rel_tables
    ]

    # Relationship tables: properties come from CALL table_info(...).
    rel_properties = []
    for table in rel_tables:
        rel_name = table["name"]
        table_entry = {"properties": [], "label": rel_name}
        info = self.conn.execute(f"CALL table_info('{rel_name}') RETURN *;")
        while info.has_next():
            info_row = info.get_next()
            # Columns 1 and 2 of each row are used as property name and type.
            table_entry["properties"].append((info_row[1], info_row[2]))
        rel_properties.append(table_entry)

    self.schema = (
        f"Node properties: {node_properties}\n"
        f"Relationships properties: {rel_properties}\n"
        f"Relationships: {relationships}\n"
    )
def _create_chunk_node_table(self) -> None:
self.conn.execute(
"""
CREATE NODE TABLE IF NOT EXISTS Chunk (
id STRING,
text STRING,
type STRING,
PRIMARY KEY(id)
);
"""
)
def _create_entity_node_table(self, node_label: str) -> None:
self.conn.execute(
f"""
CREATE NODE TABLE IF NOT EXISTS {node_label} (
id STRING,
type STRING,
PRIMARY KEY(id)
);
"""
)
def _create_entity_relationship_table(self, rel: Relationship) -> None:
self.conn.execute(
f"""
CREATE REL TABLE IF NOT EXISTS {rel.type} (
FROM {rel.source.type} TO {rel.target.type}
);
"""
)
def add_graph_documents(
    self,
    graph_documents: List[GraphDocument],
    allowed_relationships: List[Tuple[str, str, str]],
    include_source: bool = False,
) -> None:
    """
    Adds a list of `GraphDocument` objects that represent nodes and relationships
    in a graph to a Kùzu backend.

    Parameters:
      - graph_documents (List[GraphDocument]): A list of `GraphDocument` objects
        that contain the nodes and relationships to be added to the graph. Each
        `GraphDocument` should encapsulate the structure of part of the graph,
        including nodes, relationships, and the source document information.

      - allowed_relationships (List[Tuple[str, str, str]]): A list of allowed
        relationships that exist in the graph. Each tuple contains three elements:
        the source node type, the relationship type, and the target node type.
        Required for Kùzu, as the names of the relationship tables that need to
        pre-exist are derived from these tuples.
        NOTE(review): this argument is not read anywhere in the method body;
        relationship tables are created from each relationship instead.

      - include_source (bool): If True, stores the source document
        and links it to nodes in the graph using the `MENTIONS` relationship.
        This is useful for tracing back the origin of data. Merges source
        documents based on the `id` property from the source document metadata
        if available; otherwise it calculates the MD5 hash of `page_content`
        for merging process. Defaults to False.
    """
    # NOTE(review): block nesting below was reconstructed from flattened
    # text; verify it against the original file before relying on it.
    # Get unique node labels in the graph documents
    node_labels = list(
        {node.type for document in graph_documents for node in document.nodes}
    )

    for document in graph_documents:
        # Add chunk nodes and create source document relationships if include_source
        # is True
        if include_source:
            self._create_chunk_node_table()
            if not document.source.metadata.get("id"):
                # Add a unique id to each document chunk via an md5 hash
                # (this writes the id back into the source metadata)
                document.source.metadata["id"] = md5(
                    document.source.page_content.encode("utf-8")
                ).hexdigest()

            self.conn.execute(
                f"""
MERGE (c:Chunk {{id: $id}})
SET c.text = $text,
c.type = "text_chunk"
""", # noqa: F541
                parameters={
                    "id": document.source.metadata["id"],
                    "text": document.source.page_content,
                },
            )

        # Node tables must exist before the MERGE statements below run.
        for node_label in node_labels:
            self._create_entity_node_table(node_label)

        # Add entity nodes from data
        for node in document.nodes:
            self.conn.execute(
                f"""
MERGE (e:{node.type} {{id: $id}})
SET e.type = "entity"
""",
                parameters={"id": node.id},
            )
            if include_source:
                # If include_source is True, we need to create a relationship table
                # between the chunk nodes and the entity nodes
                self._create_chunk_node_table()
                ddl = "CREATE REL TABLE GROUP IF NOT EXISTS MENTIONS ("
                table_names = []
                for node_label in node_labels:
                    table_names.append(f"FROM Chunk TO {node_label}")
                # De-duplicate; set iteration order decides the clause order.
                table_names = list(set(table_names))
                ddl += ", ".join(table_names)
                # Add common properties for all the tables here
                ddl += ", label STRING, triplet_source_id STRING)"
                # ddl is always a non-empty string at this point.
                if ddl:
                    self.conn.execute(ddl)

                # Only allow relationships that exist in the schema
                if node.type in node_labels:
                    self.conn.execute(
                        f"""
MATCH (c:Chunk {{id: $id}}),
(e:{node.type} {{id: $node_id}})
MERGE (c)-[m:MENTIONS]->(e)
SET m.triplet_source_id = $id
""",
                        parameters={
                            "id": document.source.metadata["id"],
                            "node_id": node.id,
                        },
                    )

        # Add entity relationships
        for rel in document.relationships:
            # The relationship table is created (if missing) before the MERGE.
            self._create_entity_relationship_table(rel)
            # Create relationship
            source_label = rel.source.type
            source_id = rel.source.id
            target_label = rel.target.type
            target_id = rel.target.id
            self.conn.execute(
                f"""
MATCH (e1:{source_label} {{id: $source_id}}),
(e2:{target_label} {{id: $target_id}})
MERGE (e1)-[:{rel.type}]->(e2)
""",
                parameters={
                    "source_id": source_id,
                    "target_id": target_id,
                },
            )

View File

@@ -0,0 +1,525 @@
import logging
from hashlib import md5
from typing import Any, Dict, List, Optional
from langchain_core.utils import get_from_dict_or_env
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_community.graphs.graph_store import GraphStore
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Extra label that _transform_nodes adds to nodes, and that
# _transform_relationships uses for endpoints, when baseEntityLabel is true.
BASE_ENTITY_LABEL = "__Entity__"
# Statement that refresh_schema tries first to obtain the schema.
SCHEMA_QUERY = """
SHOW SCHEMA INFO
"""
# Returns one {labels, properties} map per node type, where properties is a
# list of {key, types} maps.
NODE_PROPERTIES_QUERY = """
CALL schema.node_type_properties()
YIELD nodeType AS label, propertyName AS property, propertyTypes AS type
WITH label AS nodeLabels, collect({key: property, types: type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output
"""
# Returns start labels, relationship type, end labels and the distinct
# [property, valueType] pairs found on the matched relationships.
REL_QUERY = """
MATCH (n)-[e]->(m)
WITH DISTINCT
labels(n) AS start_node_labels,
type(e) AS rel_type,
labels(m) AS end_node_labels,
e,
keys(e) AS properties
UNWIND CASE WHEN size(properties) > 0 THEN properties ELSE [null] END AS prop
WITH
start_node_labels,
rel_type,
end_node_labels,
CASE WHEN prop IS NULL THEN [] ELSE [prop, valueType(e[prop])] END AS property_info
RETURN
start_node_labels,
rel_type,
end_node_labels,
COLLECT(DISTINCT CASE
WHEN property_info <> []
THEN property_info
ELSE null END) AS properties_info
"""
# Merges one node per row of $data using row.label and row.properties.
NODE_IMPORT_QUERY = """
UNWIND $data AS row
CALL merge.node(row.label, row.properties, {}, {})
YIELD node
RETURN distinct 'done' AS result
"""
# Merges the source and target nodes of every row by id only (no label).
REL_NODES_IMPORT_QUERY = """
UNWIND $data AS row
MERGE (source {id: row.source_id})
MERGE (target {id: row.target_id})
RETURN distinct 'done' AS result
"""
# Merges a relationship of row.type between the two nodes matched by id.
REL_IMPORT_QUERY = """
UNWIND $data AS row
MATCH (source {id: row.source_id})
MATCH (target {id: row.target_id})
WITH source, target, row
CALL merge.relationship(source, row.type, {}, {}, target, {})
YIELD rel
RETURN distinct 'done' AS result
"""
# Merges a Document node keyed by $document.metadata.id and copies the page
# content and metadata onto it.
INCLUDE_DOCS_QUERY = """
MERGE (d:Document {id:$document.metadata.id})
SET d.content = $document.page_content
SET d += $document.metadata
RETURN distinct 'done' AS result
"""
# Links the Document node to every source node of $data with MENTIONS.
INCLUDE_DOCS_SOURCE_QUERY = """
UNWIND $data AS row
MATCH (source {id: row.source_id}), (d:Document {id: $document.metadata.id})
MERGE (d)-[:MENTIONS]->(source)
RETURN distinct 'done' AS result
"""
# Section headers used by transform_schema_to_text.
NODE_PROPS_TEXT = """
Node labels and properties (name and type) are:
"""
REL_PROPS_TEXT = """
Relationship labels and properties are:
"""
REL_TEXT = """
Nodes are connected with the following relationships:
"""
def get_schema_subset(data: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce a raw schema dictionary to the fields used by this module.

    Keeps labels, relationship type and properties for every edge and node.
    Each property keeps its ``key`` and its ``types``, reduced to lower-cased
    ``{"type": ...}`` entries; every other field is dropped.
    """

    def _trim_properties(properties: Any) -> List[Dict[str, Any]]:
        # Keep only the key and the lower-cased type names of each property.
        return [
            {
                "key": prop["key"],
                "types": [
                    {"type": type_item["type"].lower()} for type_item in prop["types"]
                ],
            }
            for prop in properties
        ]

    edges = []
    for edge in data["edges"]:
        edges.append(
            {
                "end_node_labels": edge["end_node_labels"],
                "properties": _trim_properties(edge["properties"]),
                "start_node_labels": edge["start_node_labels"],
                "type": edge["type"],
            }
        )

    nodes = []
    for node in data["nodes"]:
        nodes.append(
            {
                "labels": node["labels"],
                "properties": _trim_properties(node["properties"]),
            }
        )

    return {"edges": edges, "nodes": nodes}
def get_reformated_schema(
    nodes: List[Dict[str, Any]], rels: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Convert node/relationship query rows into the structured schema layout.

    Relationship rows supply ``properties_info`` pairs of (name, type). Node
    rows supply a ``labels`` string, from which backticks and the first
    character are removed, and a ``properties`` list, which is dropped
    entirely when its first entry has an empty key.
    """

    def _edge_entry(rel: Dict[str, Any]) -> Dict[str, Any]:
        return {
            "end_node_labels": rel["end_node_labels"],
            "properties": [
                {"key": info[0], "types": [{"type": info[1].lower()}]}
                for info in rel["properties_info"]
            ],
            "start_node_labels": rel["start_node_labels"],
            "type": rel["rel_type"],
        }

    def _node_entry(node: Dict[str, Any]) -> Dict[str, Any]:
        labels = [_remove_backticks(node["labels"])[1:]]
        raw_props = node["properties"]
        node_props: List[Dict[str, Any]] = []
        # An empty key on the first property means "no properties at all".
        if raw_props and raw_props[0]["key"] != "":
            node_props = [
                {
                    "key": prop["key"],
                    "types": [{"type": type_name.lower()} for type_name in prop["types"]],
                }
                for prop in raw_props
            ]
        return {"labels": labels, "properties": node_props}

    edges = [_edge_entry(rel) for rel in rels]
    formatted_nodes = [_node_entry(node) for node in nodes]
    return {"edges": edges, "nodes": formatted_nodes}
def transform_schema_to_text(schema: Dict[str, Any]) -> str:
    """Render a structured schema as text.

    Produces up to three sections (node properties, relationship properties,
    relationship patterns), each preceded by its module-level header. A
    section with no content is omitted.
    """
    node_chunks: List[str] = []
    rel_prop_chunks: List[str] = []
    rel_chunks: List[str] = []

    for node in schema["nodes"]:
        node_chunks.append(f"- labels: (:{':'.join(node['labels'])})\n")
        if node["properties"] == []:
            continue
        node_chunks.append(" properties:\n")
        for prop in node["properties"]:
            # A set removes duplicate type names before joining.
            type_names = {entry["type"] for entry in prop["types"]}
            joined_types = " or ".join(type_names)
            node_chunks.append(f" - {prop['key']}: {joined_types}\n")

    for rel in schema["edges"]:
        rel_type = rel["type"]
        start_labels = ":".join(rel["start_node_labels"])
        end_labels = ":".join(rel["end_node_labels"])
        rel_chunks.append(f"(:{start_labels})-[:{rel_type}]->(:{end_labels})\n")
        if rel["properties"] == []:
            continue
        rel_prop_chunks.append(f"- labels: {rel_type}\n properties:\n")
        for prop in rel["properties"]:
            type_names = {entry["type"].lower() for entry in prop["types"]}
            joined_types = " or ".join(type_names)
            rel_prop_chunks.append(f" - {prop['key']}: {joined_types}\n")

    node_props_data = "".join(node_chunks)
    rel_props_data = "".join(rel_prop_chunks)
    rel_data = "".join(rel_chunks)

    sections = []
    if node_props_data:
        sections.append(NODE_PROPS_TEXT + node_props_data)
    if rel_props_data:
        sections.append(REL_PROPS_TEXT + rel_props_data)
    if rel_data:
        sections.append(REL_TEXT + rel_data)
    return "".join(sections)
def _remove_backticks(text: str) -> str:
return text.replace("`", "")
def _transform_nodes(nodes: list[Node], baseEntityLabel: bool) -> List[dict]:
transformed_nodes = []
for node in nodes:
properties_dict = node.properties | {"id": node.id}
label = (
[_remove_backticks(node.type), BASE_ENTITY_LABEL]
if baseEntityLabel
else [_remove_backticks(node.type)]
)
node_dict = {"label": label, "properties": properties_dict}
transformed_nodes.append(node_dict)
return transformed_nodes
def _transform_relationships(
relationships: list[Relationship], baseEntityLabel: bool
) -> List[dict]:
transformed_relationships = []
for rel in relationships:
rel_dict = {
"type": _remove_backticks(rel.type),
"source_label": (
[BASE_ENTITY_LABEL]
if baseEntityLabel
else [_remove_backticks(rel.source.type)]
),
"source_id": rel.source.id,
"target_label": (
[BASE_ENTITY_LABEL]
if baseEntityLabel
else [_remove_backticks(rel.target.type)]
),
"target_id": rel.target.id,
}
transformed_relationships.append(rel_dict)
return transformed_relationships
class MemgraphGraph(GraphStore):
"""Memgraph wrapper for graph operations.
Parameters:
url (Optional[str]): The URL of the Memgraph database server.
username (Optional[str]): The username for database authentication.
password (Optional[str]): The password for database authentication.
database (str): The name of the database to connect to. Default is 'memgraph'.
refresh_schema (bool): A flag whether to refresh schema information
at initialization. Default is True.
driver_config (Dict): Configuration passed to Neo4j Driver.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
def __init__(
    self,
    url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    database: Optional[str] = None,
    refresh_schema: bool = True,
    *,
    driver_config: Optional[Dict] = None,
) -> None:
    """Create a new Memgraph graph wrapper instance.

    Args:
        url: Server URL; resolved via ``get_from_dict_or_env`` with
            ``MEMGRAPH_URI``.
        username: User name; resolved with ``MEMGRAPH_USERNAME``.
        password: Password; resolved with ``MEMGRAPH_PASSWORD``.
        database: Database name; resolved with ``MEMGRAPH_DATABASE`` and
            defaulting to ``"memgraph"``.
        refresh_schema: Whether to load the schema during construction.
        driver_config: Extra keyword arguments for the neo4j driver.

    Raises:
        ImportError: If the ``neo4j`` package cannot be imported.
        ValueError: If the connectivity check fails because the service is
            unavailable or authentication is rejected. The driver exception
            is attached as ``__cause__``.
    """
    try:
        import neo4j
    except ImportError as e:
        raise ImportError(
            "Could not import neo4j python package. "
            "Please install it with `pip install neo4j`."
        ) from e

    url = get_from_dict_or_env({"url": url}, "url", "MEMGRAPH_URI")

    # if username and password are "", assume auth is disabled
    if username == "" and password == "":
        auth = None
    else:
        username = get_from_dict_or_env(
            {"username": username},
            "username",
            "MEMGRAPH_USERNAME",
        )
        password = get_from_dict_or_env(
            {"password": password},
            "password",
            "MEMGRAPH_PASSWORD",
        )
        auth = (username, password)
    database = get_from_dict_or_env(
        {"database": database}, "database", "MEMGRAPH_DATABASE", "memgraph"
    )

    self._driver = neo4j.GraphDatabase.driver(
        url, auth=auth, **(driver_config or {})
    )
    self._database = database
    self.schema: str = ""
    self.structured_schema: Dict[str, Any] = {}

    # Verify connection
    try:
        self._driver.verify_connectivity()
    except neo4j.exceptions.ServiceUnavailable as e:
        raise ValueError(
            "Could not connect to Memgraph database. "
            "Please ensure that the url is correct"
        ) from e
    except neo4j.exceptions.AuthError as e:
        raise ValueError(
            "Could not connect to Memgraph database. "
            "Please ensure that the username and password are correct"
        ) from e

    # Set schema. Errors raised by the refresh propagate unchanged; the
    # earlier catch-and-re-raise wrapper added nothing.
    if refresh_schema:
        self.refresh_schema()
def close(self) -> None:
    """Close the driver connection, if any, and clear the reference to it."""
    if not self._driver:
        return
    logger.info("Closing the driver connection.")
    self._driver.close()
    self._driver = None
@property
def get_schema(self) -> str:
    """Return the schema text currently stored on ``self.schema``."""
    schema_text = self.schema
    return schema_text
@property
def get_structured_schema(self) -> Dict[str, Any]:
    """Return the dictionary currently stored on ``self.structured_schema``."""
    structured = self.structured_schema
    return structured
def query(self, query: str, params: Optional[dict] = None) -> List[Dict[str, Any]]:
    """Query the graph.

    The query is first sent through ``driver.execute_query``. If the server
    reports that the statement cannot run that way (see the error checks
    below), it is retried through a plain session.

    Args:
        query (str): The Cypher query to execute.
        params (Optional[dict]): The parameters to pass to the query.
            ``None`` (the default) is treated as an empty mapping.

    Returns:
        List[Dict[str, Any]]: The list of dictionaries containing the query results.

    Raises:
        neo4j.exceptions.Neo4jError: Re-raised unchanged when the error is not
            one of the recognised "retry in a session" cases.
    """
    from neo4j.exceptions import Neo4jError

    # A fresh dict per call instead of one shared mutable default argument.
    params = {} if params is None else params

    try:
        data, _, _ = self._driver.execute_query(
            query,
            database_=self._database,
            parameters_=params,
        )
        return [r.data() for r in data]
    except Neo4jError as e:
        retry_in_session = (
            (  # isCallInTransactionError
                e.code
                in (
                    "Neo.DatabaseError.Statement.ExecutionFailed",
                    "Neo.DatabaseError.Transaction.TransactionStartFailed",
                )
                and "in an implicit transaction" in e.message
            )
            or (  # isPeriodicCommitError
                e.code == "Neo.ClientError.Statement.SemanticError"
                and (
                    "in an open transaction is not possible" in e.message
                    or "tried to execute in an explicit transaction" in e.message
                )
            )
            or (  # Memgraph-specific errors
                e.code == "Memgraph.ClientError.MemgraphError.MemgraphError"
                and (
                    "in multicommand transactions" in e.message
                    or "SchemaInfo disabled" in e.message
                )
            )
        )
        if not retry_in_session:
            raise
    # fallback to allow implicit transactions
    with self._driver.session(database=self._database) as session:
        records = session.run(query, params)
        return [r.data() for r in records]
def refresh_schema(self) -> None:
    """Refresh the Memgraph graph schema information.

    Leaves the schema untouched when the database has no nodes. Otherwise
    ``SCHEMA_QUERY`` is tried first; if that raises a ``Neo4jError`` the
    schema is rebuilt from ``NODE_PROPERTIES_QUERY`` and ``REL_QUERY``.
    """
    import ast

    from neo4j.exceptions import Neo4jError

    # leave schema empty if db is empty
    if self.query("MATCH (n) RETURN n LIMIT 1") == []:
        return

    # first try with SHOW SCHEMA INFO
    try:
        raw_schema = self.query(SCHEMA_QUERY)[0].get("schema")
        parsed_schema = raw_schema
        if raw_schema is not None and isinstance(raw_schema, (str, ast.AST)):
            parsed_schema = ast.literal_eval(raw_schema)
        assert parsed_schema is not None
        subset = get_schema_subset(parsed_schema)
        self.structured_schema = subset
        self.schema = transform_schema_to_text(subset)
        return
    except Neo4jError as e:
        schema_info_disabled = (
            e.code == "Memgraph.ClientError.MemgraphError.MemgraphError"
            and "SchemaInfo disabled" in e.message
        )
        if schema_info_disabled:
            logger.info(
                "Schema generation with SHOW SCHEMA INFO query failed. "
                "Set --schema-info-enabled=true to use SHOW SCHEMA INFO query. "
                "Falling back to alternative queries."
            )

    # fallback on Cypher without SHOW SCHEMA INFO
    node_rows = self.query(NODE_PROPERTIES_QUERY)
    nodes = [row["output"] for row in node_rows]
    rels = self.query(REL_QUERY)
    rebuilt = get_reformated_schema(nodes, rels)
    self.structured_schema = rebuilt
    self.schema = transform_schema_to_text(rebuilt)
def add_graph_documents(
self,
graph_documents: List[GraphDocument],
include_source: bool = False,
baseEntityLabel: bool = False,
) -> None:
"""
Take GraphDocument objects as input and use them to construct a graph
in Memgraph.
Parameters:
- graph_documents (List[GraphDocument]): A list of GraphDocument objects
that contain the nodes and relationships to be added to the graph. Each
GraphDocument should encapsulate the structure of part of the graph,
including nodes, relationships, and the source document information.
- include_source (bool, optional): If True, stores the source document
and links it to nodes in the graph using the MENTIONS relationship.
This is useful for tracing back the origin of data. Merges source
documents based on the `id` property from the source document metadata
if available; otherwise it calculates the MD5 hash of `page_content`
for merging process. Defaults to False.
- baseEntityLabel (bool, optional): If True, each newly created node
gets a secondary __Entity__ label, which is indexed and improves import
speed and performance. Defaults to False.
Side effects:
- May write an `id` into `document.source.metadata` (see include_source).
- Calls `refresh_schema()` at the end.
"""
# With the base entity label enabled, create a uniqueness constraint and
# indexes for it before importing anything.
if baseEntityLabel:
self.query(
f"CREATE CONSTRAINT ON (b:{BASE_ENTITY_LABEL}) ASSERT b.id IS UNIQUE;"
)
self.query(f"CREATE INDEX ON :{BASE_ENTITY_LABEL}(id);")
self.query(f"CREATE INDEX ON :{BASE_ENTITY_LABEL};")
for document in graph_documents:
if include_source:
# Fall back to an MD5 of the page content when the source has no id.
if not document.source.metadata.get("id"):
document.source.metadata["id"] = md5(
document.source.page_content.encode("utf-8")
).hexdigest()
self.query(INCLUDE_DOCS_QUERY, {"document": document.source.__dict__})
# Import the document's nodes.
self.query(
NODE_IMPORT_QUERY,
{"data": _transform_nodes(document.nodes, baseEntityLabel)},
)
# The same transformed relationship rows feed all relationship queries below.
rel_data = _transform_relationships(document.relationships, baseEntityLabel)
self.query(
REL_NODES_IMPORT_QUERY,
{"data": rel_data},
)
self.query(
REL_IMPORT_QUERY,
{"data": rel_data},
)
if include_source:
self.query(
INCLUDE_DOCS_SOURCE_QUERY,
{"data": rel_data, "document": document.source.__dict__},
)
# NOTE(review): indentation was lost in this copy; presumably this runs once
# after the loop rather than per document - confirm against the original file.
self.refresh_schema()

View File

@@ -0,0 +1,222 @@
import logging
from string import Template
from typing import Any, Dict, Optional
# Module-level logger used for query warnings and retry messages.
logger = logging.getLogger(__name__)
# Query template used by NebulaGraph.refresh_schema: picks one edge of
# `$edge_type` and renders it as a "(:src_tag)-[:edge_type]->(:dst_tag)"
# string, using the first tag of each endpoint, returned in column `rels`.
rel_query = Template(
"""
MATCH ()-[e:`$edge_type`]->()
WITH e limit 1
MATCH (m)-[:`$edge_type`]->(n) WHERE id(m) == src(e) AND id(n) == dst(e)
RETURN "(:" + tags(m)[0] + ")-[:$edge_type]->(:" + tags(n)[0] + ")" AS rels
"""
)
# Upper bound on the retries NebulaGraph.execute performs before giving up.
RETRY_TIMES = 3
class NebulaGraph:
"""NebulaGraph wrapper for graph operations.
NebulaGraph inherits methods from Neo4jGraph to bring ease to the user space.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
def __init__(
    self,
    space: str,
    username: str = "root",
    password: str = "nebula",
    address: str = "127.0.0.1",
    port: int = 9669,
    session_pool_size: int = 30,
) -> None:
    """Create a new NebulaGraph wrapper instance.

    Stores the connection settings, opens a session pool and loads the schema.

    Raises:
        ImportError: If ``nebula3`` or ``pandas`` cannot be imported.
        ValueError: If the session pool cannot be created or the schema
            cannot be refreshed.
    """
    try:
        import nebula3  # noqa: F401
        import pandas  # noqa: F401
    except ImportError:
        raise ImportError(
            "Please install NebulaGraph Python client and pandas first: "
            "`pip install nebula3-python pandas`"
        )

    # Connection settings must be in place before the pool is built from them.
    self.username, self.password = username, password
    self.address, self.port = address, port
    self.space = space
    self.session_pool_size = session_pool_size

    self.session_pool = self._get_session_pool()
    self.schema = ""

    # Set schema
    try:
        self.refresh_schema()
    except Exception as e:
        raise ValueError(f"Could not refresh schema. Error: {e}")
def _get_session_pool(self) -> Any:
    """Build and initialise a nebula3 ``SessionPool`` from the stored settings.

    Returns:
        The initialised session pool.

    Raises:
        AssertionError: If any of username, password, address, port or space
            is falsy.
        ValueError: If the host is invalid, authentication fails, or pool
            initialisation raises ``RuntimeError``.
    """
    required = [
        self.username,
        self.password,
        self.address,
        self.port,
        self.space,
    ]
    assert all(required), (
        "Please provide all of the following parameters: "
        "username, password, address, port, space"
    )

    from nebula3.Config import SessionPoolConfig
    from nebula3.Exception import AuthFailedException, InValidHostname
    from nebula3.gclient.net.SessionPool import SessionPool

    pool_config = SessionPoolConfig()
    pool_config.max_size = self.session_pool_size

    endpoints = [(self.address, self.port)]
    try:
        pool = SessionPool(self.username, self.password, self.space, endpoints)
    except InValidHostname:
        raise ValueError(
            "Could not connect to NebulaGraph database. "
            "Please ensure that the address and port are correct"
        )

    try:
        pool.init(pool_config)
    except AuthFailedException:
        raise ValueError(
            "Could not connect to NebulaGraph database. "
            "Please ensure that the username and password are correct"
        )
    except RuntimeError as e:
        raise ValueError(f"Error initializing session pool. Error: {e}")

    return pool
def __del__(self) -> None:
    """Best-effort close of the session pool when the wrapper is collected."""
    try:
        self.session_pool.close()
    except Exception as close_error:
        # Never let a cleanup failure escape from the finalizer; just log it.
        logger.warning(f"Could not close session pool. Error: {close_error}")
@property
def get_schema(self) -> str:
    """Schema text of the NebulaGraph space as last stored in ``self.schema``."""
    return self.schema
def execute(self, query: str, params: Optional[dict] = None, retry: int = 0) -> Any:
    """Query NebulaGraph database.

    Args:
        query: The query text to execute.
        params: Query parameters; ``None`` is treated as an empty mapping.
        retry: Number of retries already performed (used by the recursive
            retry calls; callers normally leave it at 0).

    Returns:
        The result object returned by the session pool. A result that reports
        failure is logged as a warning but still returned.

    Raises:
        ValueError: If no valid session is available, or if the query still
            fails with a ``RuntimeError`` or a connection error after
            ``RETRY_TIMES`` retries.
    """
    from nebula3.Exception import IOErrorException, NoValidSessionException
    from nebula3.fbthrift.transport.TTransport import TTransportException

    params = params or {}
    try:
        result = self.session_pool.execute_parameter(query, params)
        if not result.is_succeeded():
            logger.warning(
                f"Error executing query to NebulaGraph. "
                f"Error: {result.error_msg()}\n"
                f"Query: {query} \n"
            )
        return result
    except NoValidSessionException:
        message = (
            f"No valid session found in session pool. "
            f"Please consider increasing the session pool size. "
            f"Current size: {self.session_pool_size}"
        )
        logger.warning(message)
        raise ValueError(message)
    except RuntimeError as e:
        if retry >= RETRY_TIMES:
            raise ValueError(f"Error executing query to NebulaGraph. Error: {e}")
        retry += 1
        logger.warning(
            f"Error executing query to NebulaGraph. "
            f"Retrying ({retry}/{RETRY_TIMES})...\n"
            f"query: {query} \n"
            f"Error: {e}"
        )
        return self.execute(query, params, retry)
    except (TTransportException, IOErrorException) as e:
        # Previously this path fell through and returned None once the
        # retries were used up, which made callers fail later with an
        # unrelated AttributeError. Report the failure explicitly instead.
        if retry >= RETRY_TIMES:
            raise ValueError(
                f"Connection issue with NebulaGraph persisted after "
                f"{RETRY_TIMES} retries. Error: {e}"
            ) from e
        # connection issue, try to recreate session pool
        retry += 1
        logger.warning(
            f"Connection issue with NebulaGraph. "
            f"Retrying ({retry}/{RETRY_TIMES})...\n to recreate session pool"
        )
        self.session_pool = self._get_session_pool()
        return self.execute(query, params, retry)
def refresh_schema(self) -> None:
    """Refresh the NebulaGraph schema information.

    Describes every tag and edge type, samples one relationship pattern per
    edge type via ``rel_query``, and stores the summary in ``self.schema``.
    """

    def describe(kind: str, name: str) -> list:
        # (field, type) pairs reported by "DESCRIBE TAG" / "DESCRIBE EDGE".
        described = self.execute(f"DESCRIBE {kind} `{name}`")
        fields = described.column_values("Field")
        field_types = described.column_values("Type")
        return [
            (fields[idx].cast(), field_types[idx].cast())
            for idx in range(described.row_size())
        ]

    tags_schema = []
    for tag in self.execute("SHOW TAGS").column_values("Name"):
        tag_name = tag.cast()
        tags_schema.append(
            {"tag": tag_name, "properties": describe("TAG", tag_name)}
        )

    edge_types_schema, relationships = [], []
    for edge_type in self.execute("SHOW EDGES").column_values("Name"):
        edge_type_name = edge_type.cast()
        edge_types_schema.append(
            {"edge": edge_type_name, "properties": describe("EDGE", edge_type_name)}
        )
        # build relationships types
        sampled = self.execute(
            rel_query.substitute(edge_type=edge_type_name)
        ).column_values("rels")
        if len(sampled) > 0:
            relationships.append(sampled[0].cast())

    self.schema = (
        f"Node properties: {tags_schema}\n"
        f"Edge properties: {edge_types_schema}\n"
        f"Relationships: {relationships}\n"
    )
def query(self, query: str, retry: int = 0) -> Dict[str, Any]:
    """Run a query and return its result as a column-oriented dictionary.

    Args:
        query: The query text to execute.
        retry: Retry counter forwarded to ``execute``.

    Returns:
        A mapping from each result column name to the list of that column's
        values, each converted with ``cast()``.
    """
    result = self.execute(query, retry=retry)
    names = result.keys()
    return {
        names[idx]: [cell.cast() for cell in result.column_values(names[idx])]
        for idx in range(result.col_size())
    }

View File

@@ -0,0 +1,848 @@
from hashlib import md5
from typing import Any, Dict, List, Optional
from langchain_core._api.deprecation import deprecated
from langchain_core.utils import get_from_dict_or_env
from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore
# Secondary label given to imported nodes when `baseEntityLabel` is enabled.
BASE_ENTITY_LABEL = "__Entity__"
# Labels / relationship types passed as $EXCLUDED_LABELS so they are left out
# of the generated schema.
EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"]
EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"]
# Element-count threshold: below it the enhanced schema scans all elements,
# otherwise it only samples.
EXHAUSTIVE_SEARCH_LIMIT = 10000
# Lists with at least this many elements are dropped by `value_sanitize`.
LIST_LIMIT = 128
# Threshold for returning all available prop values in graph schema
DISTINCT_VALUE_LIMIT = 10
# Node labels with their (property, type) pairs, taken from apoc.meta.data().
node_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
AND NOT label IN $EXCLUDED_LABELS
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output
"""
# Relationship types with their (property, type) pairs.
rel_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
AND NOT label in $EXCLUDED_LABELS
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
RETURN {type: nodeLabels, properties: properties} AS output
"""
# (start label, relationship type, end label) triples between node labels.
rel_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE type = "RELATIONSHIP" AND elementType = "node"
UNWIND other AS other_node
WITH * WHERE NOT label IN $EXCLUDED_LABELS
AND NOT other_node IN $EXCLUDED_LABELS
RETURN {start: label, type: property, end: toString(other_node)} AS output
"""
# Cypher prefix that merges the source Document node (bound to `d`); it is
# prepended to the node import query when `include_source` is set.
include_docs_query = (
"MERGE (d:Document {id:$document.metadata.id}) "
"SET d.text = $document.page_content "
"SET d += $document.metadata "
"WITH d "
)
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph.clean_string_values",
)
def clean_string_values(text: str) -> str:
    """Clean string values for schema.

    Every newline and carriage-return character is replaced by one space.

    Args:
        text (str): The input text to clean.

    Returns:
        str: The cleaned text.
    """
    # One translation pass maps both line-break characters to a space.
    return text.translate(str.maketrans("\n\r", "  "))
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph.value_sanitize",
)
def value_sanitize(d: Any) -> Any:
    """Sanitize the input dictionary or list.

    Sanitizes the input by removing embedding-like values,
    lists with ``LIST_LIMIT`` (128) or more elements, that are mostly
    irrelevant for generating answers in a LLM context. These properties, if
    left in results, can occupy significant context space and detract from
    the LLM's performance by introducing unnecessary noise and cost.

    Args:
        d (Any): The input dictionary or list to sanitize.

    Returns:
        Any: The sanitized dictionary or list. An oversized list yields
        ``None``; any other non-container value is returned unchanged.
    """
    if isinstance(d, dict):
        new_dict = {}
        for key, value in d.items():
            if isinstance(value, dict):
                sanitized_value = value_sanitize(value)
                if sanitized_value is not None:
                    new_dict[key] = sanitized_value
            elif isinstance(value, list):
                # Do not include the key if the list is oversized
                if len(value) < LIST_LIMIT:
                    sanitized_value = value_sanitize(value)
                    if sanitized_value is not None:
                        new_dict[key] = sanitized_value
            else:
                new_dict[key] = value
        return new_dict

    if isinstance(d, list):
        if len(d) >= LIST_LIMIT:
            return None
        # Sanitize each element exactly once. Calling the function twice per
        # element (once to filter, once to keep) doubles the work at every
        # nesting level.
        kept = []
        for item in d:
            sanitized_item = value_sanitize(item)
            if sanitized_item is not None:
                kept.append(sanitized_item)
        return kept

    return d
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph._get_node_import_query",
)
def _get_node_import_query(baseEntityLabel: bool, include_source: bool) -> str:
    """Build the Cypher statement that imports the nodes passed as ``$data``.

    Args:
        baseEntityLabel: Merge nodes on the base entity label and add the row
            type as an extra label, instead of merging on the row type.
        include_source: Prepend the source-document merge and link each node
            to it with a MENTIONS relationship.

    Returns:
        The assembled Cypher query text.
    """
    docs_prefix = include_docs_query if include_source else ""

    if baseEntityLabel:
        mentions = "MERGE (d)-[:MENTIONS]->(source) " if include_source else ""
        parts = [
            docs_prefix,
            "UNWIND $data AS row ",
            f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.id}}) ",
            "SET source += row.properties ",
            mentions,
            "WITH source, row ",
            "CALL apoc.create.addLabels( source, [row.type] ) YIELD node ",
            "RETURN distinct 'done' AS result",
        ]
    else:
        mentions = "MERGE (d)-[:MENTIONS]->(node) " if include_source else ""
        parts = [
            docs_prefix,
            "UNWIND $data AS row ",
            "CALL apoc.merge.node([row.type], {id: row.id}, ",
            "row.properties, {}) YIELD node ",
            mentions,
            "RETURN distinct 'done' AS result",
        ]
    return "".join(parts)
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph._get_rel_import_query",
)
def _get_rel_import_query(baseEntityLabel: bool) -> str:
    """Build the Cypher statement that imports relationships passed as ``$data``.

    Args:
        baseEntityLabel: Merge both endpoints on the base entity label rather
            than on the per-row source/target labels.

    Returns:
        The assembled Cypher query text.
    """
    # Both variants end with the same relationship merge.
    merge_relationship = (
        "CALL apoc.merge.relationship(source, row.type, "
        "{}, row.properties, target) YIELD rel "
        "RETURN distinct 'done'"
    )

    if baseEntityLabel:
        merge_endpoints = (
            f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.source}}) "
            f"MERGE (target:`{BASE_ENTITY_LABEL}` {{id: row.target}}) "
            "WITH source, target, row "
        )
    else:
        merge_endpoints = (
            "CALL apoc.merge.node([row.source_label], {id: row.source},"
            "{}, {}) YIELD node as source "
            "CALL apoc.merge.node([row.target_label], {id: row.target},"
            "{}, {}) YIELD node as target "
        )
    return "UNWIND $data AS row " + merge_endpoints + merge_relationship
def _format_enhanced_prop_example(prop: Dict[str, Any]) -> Optional[str]:
    """Return the example/statistics suffix for one property in enhanced mode.

    Returns ``None`` for LIST properties that must be left out of the schema
    text (no recorded minimum size, or a minimum size above ``LIST_LIMIT``).
    """
    prop_type = prop["type"]
    values = prop.get("values")
    if prop_type == "STRING" and values:
        if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT:
            return f'Example: "{clean_string_values(values[0])}"'
        # If less than 10 possible values return all
        return f"Available options: {[clean_string_values(el) for el in values]}"
    if prop_type in ("INTEGER", "FLOAT", "DATE", "DATE_TIME", "LOCAL_DATE_TIME"):
        if prop.get("min") is not None:  # If we have min/max
            return f"Min: {prop['min']}, Max: {prop['max']}"
        # otherwise return a single value, if any were collected
        return f'Example: "{values[0]}"' if values else ""
    if prop_type == "LIST":
        # Skip embeddings
        if not prop.get("min_size") or prop["min_size"] > LIST_LIMIT:
            return None
        return f"Min Size: {prop['min_size']}, Max Size: {prop['max_size']}"
    return ""


@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph._format_schema",
)
def _format_schema(schema: Dict, is_enhanced: bool) -> str:
    """Render a structured schema dictionary as text.

    Args:
        schema: Dictionary with ``node_props``, ``rel_props`` and
            ``relationships`` entries.
        is_enhanced: When True, each property line also carries example
            values or statistics; otherwise a compact one-line-per-label
            form is produced.

    Returns:
        The schema text with node properties, relationship properties and
        relationship patterns.
    """
    formatted_node_props = []
    formatted_rel_props = []
    if is_enhanced:
        # Enhanced formatting for nodes
        for node_type, properties in schema["node_props"].items():
            formatted_node_props.append(f"- **{node_type}**")
            for prop in properties:
                example = _format_enhanced_prop_example(prop)
                if example is None:
                    continue
                formatted_node_props.append(
                    f"  - `{prop['property']}`: {prop['type']} {example}"
                )

        # Enhanced formatting for relationships. This uses the same example
        # logic as nodes: the relationship branch used to index
        # prop["values"] directly and raised KeyError when no values had
        # been collected for a property.
        for rel_type, properties in schema["rel_props"].items():
            formatted_rel_props.append(f"- **{rel_type}**")
            for prop in properties:
                example = _format_enhanced_prop_example(prop)
                if example is None:
                    continue
                formatted_rel_props.append(
                    f"  - `{prop['property']}: {prop['type']}` {example}"
                )
    else:
        # Format node properties
        for label, props in schema["node_props"].items():
            props_str = ", ".join(
                [f"{prop['property']}: {prop['type']}" for prop in props]
            )
            formatted_node_props.append(f"{label} {{{props_str}}}")

        # Format relationship properties using structured_schema
        for rel_type, props in schema["rel_props"].items():
            props_str = ", ".join(
                [f"{prop['property']}: {prop['type']}" for prop in props]
            )
            formatted_rel_props.append(f"{rel_type} {{{props_str}}}")

    # Format relationships
    formatted_rels = [
        f"(:{el['start']})-[:{el['type']}]->(:{el['end']})"
        for el in schema["relationships"]
    ]

    return "\n".join(
        [
            "Node properties:",
            "\n".join(formatted_node_props),
            "Relationship properties:",
            "\n".join(formatted_rel_props),
            "The relationships:",
            "\n".join(formatted_rels),
        ]
    )
@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.graphs.neo4j_graph._remove_backticks",
)
def _remove_backticks(text: str) -> str:
    """Return ``text`` with every backtick character removed."""
    return "".join(text.split("`"))
@deprecated(
since="0.3.8",
removal="1.0",
alternative_import="langchain_neo4j.Neo4jGraph",
)
class Neo4jGraph(GraphStore):
"""Neo4j database wrapper for various graph operations.
Parameters:
url (Optional[str]): The URL of the Neo4j database server.
username (Optional[str]): The username for database authentication.
password (Optional[str]): The password for database authentication.
database (str): The name of the database to connect to. Default is 'neo4j'.
timeout (Optional[float]): The timeout for transactions in seconds.
Useful for terminating long-running queries.
By default, there is no timeout set.
sanitize (bool): A flag to indicate whether to remove lists with
more than 128 elements from results. Useful for removing
embedding-like properties from database responses. Default is False.
refresh_schema (bool): A flag whether to refresh schema information
at initialization. Default is True.
enhanced_schema (bool): A flag whether to scan the database for
example values and use them in the graph schema. Default is False.
driver_config (Dict): Configuration passed to Neo4j Driver.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
def __init__(
    self,
    url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    database: Optional[str] = None,
    timeout: Optional[float] = None,
    sanitize: bool = False,
    refresh_schema: bool = True,
    *,
    driver_config: Optional[Dict] = None,
    enhanced_schema: bool = False,
) -> None:
    """Create a new Neo4j graph wrapper instance.

    Connection settings not passed explicitly are resolved through
    ``get_from_dict_or_env`` with the keys ``NEO4J_URI``, ``NEO4J_USERNAME``,
    ``NEO4J_PASSWORD`` and ``NEO4J_DATABASE`` (default ``"neo4j"``).

    Raises:
        ImportError: If the ``neo4j`` package is not installed.
        ValueError: If the server is unreachable, rejects the credentials,
            or the APOC procedures needed for the schema are missing.
    """
    try:
        import neo4j
    except ImportError as e:
        raise ImportError(
            "Could not import neo4j python package. "
            "Please install it with `pip install neo4j`."
        ) from e

    url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI")
    # if username and password are "", assume Neo4j auth is disabled
    if username == "" and password == "":
        auth = None
    else:
        username = get_from_dict_or_env(
            {"username": username},
            "username",
            "NEO4J_USERNAME",
        )
        password = get_from_dict_or_env(
            {"password": password},
            "password",
            "NEO4J_PASSWORD",
        )
        auth = (username, password)
    database = get_from_dict_or_env(
        {"database": database}, "database", "NEO4J_DATABASE", "neo4j"
    )

    self._driver = neo4j.GraphDatabase.driver(
        url, auth=auth, **(driver_config or {})
    )
    self._database = database
    self.timeout = timeout
    self.sanitize = sanitize
    self._enhanced_schema = enhanced_schema
    self.schema: str = ""
    self.structured_schema: Dict[str, Any] = {}

    # Verify connection; release the driver before reporting a failure so the
    # half-constructed instance does not leak its connection pool.
    try:
        self._driver.verify_connectivity()
    except neo4j.exceptions.ServiceUnavailable as e:
        self._driver.close()
        raise ValueError(
            "Could not connect to Neo4j database. "
            "Please ensure that the url is correct"
        ) from e
    except neo4j.exceptions.AuthError as e:
        self._driver.close()
        raise ValueError(
            "Could not connect to Neo4j database. "
            "Please ensure that the username and password are correct"
        ) from e

    # Set schema
    if refresh_schema:
        try:
            self.refresh_schema()
        except neo4j.exceptions.ClientError as e:
            if e.code == "Neo.ClientError.Procedure.ProcedureNotFound":
                raise ValueError(
                    "Could not use APOC procedures. "
                    "Please ensure the APOC plugin is installed in Neo4j and that "
                    "'apoc.meta.data()' is allowed in Neo4j configuration "
                ) from e
            raise
@property
def get_schema(self) -> str:
    """Textual schema of the graph as last stored in ``self.schema``."""
    return self.schema
@property
def get_structured_schema(self) -> Dict[str, Any]:
    """Structured (dictionary) schema as last stored in ``self.structured_schema``."""
    return self.structured_schema
def query(
    self,
    query: str,
    params: Optional[dict] = None,
) -> List[Dict[str, Any]]:
    """Query Neo4j database.

    The query is first sent through ``driver.execute_query``. If the server
    reports that the statement needs an implicit transaction, it is retried
    through a plain session. Results are passed through ``value_sanitize``
    when ``self.sanitize`` is set.

    Args:
        query (str): The Cypher query to execute.
        params (Optional[dict]): The parameters to pass to the query.
            ``None`` (the default) is treated as an empty mapping.

    Returns:
        List[Dict[str, Any]]: The list of dictionaries containing the query results.

    Raises:
        neo4j.exceptions.Neo4jError: Re-raised unchanged when the error is not
            one of the recognised implicit-transaction cases.
    """
    from neo4j import Query
    from neo4j.exceptions import Neo4jError

    # A fresh dict per call instead of one shared mutable default argument.
    params = {} if params is None else params

    try:
        data, _, _ = self._driver.execute_query(
            Query(text=query, timeout=self.timeout),
            database_=self._database,
            parameters_=params,
        )
        json_data = [r.data() for r in data]
        if self.sanitize:
            json_data = [value_sanitize(el) for el in json_data]
        return json_data
    except Neo4jError as e:
        needs_implicit_tx = (
            (  # isCallInTransactionError
                e.code
                in (
                    "Neo.DatabaseError.Statement.ExecutionFailed",
                    "Neo.DatabaseError.Transaction.TransactionStartFailed",
                )
                and "in an implicit transaction" in e.message
            )
            or (  # isPeriodicCommitError
                e.code == "Neo.ClientError.Statement.SemanticError"
                and (
                    "in an open transaction is not possible" in e.message
                    or "tried to execute in an explicit transaction" in e.message
                )
            )
        )
        if not needs_implicit_tx:
            raise
    # fallback to allow implicit transactions
    with self._driver.session(database=self._database) as session:
        records = session.run(Query(text=query, timeout=self.timeout), params)
        json_data = [r.data() for r in records]
        if self.sanitize:
            json_data = [value_sanitize(el) for el in json_data]
        return json_data
def refresh_schema(self) -> None:
"""
Refreshes the Neo4j graph schema information.
Rebuilds `self.structured_schema` (node properties, relationship
properties, relationship patterns, constraint/index metadata) and the
text form `self.schema`. When `self._enhanced_schema` is set, properties
are additionally enriched with sample values or statistics.
"""
from neo4j.exceptions import ClientError, CypherTypeError
# Basic schema: three apoc.meta.data() based queries, each row wrapped in "output".
node_properties = [
el["output"]
for el in self.query(
node_properties_query,
params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
)
]
rel_properties = [
el["output"]
for el in self.query(
rel_properties_query, params={"EXCLUDED_LABELS": EXCLUDED_RELS}
)
]
relationships = [
el["output"]
for el in self.query(
rel_query,
params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
)
]
# Get constraints & indexes
try:
constraint = self.query("SHOW CONSTRAINTS")
index = self.query(
"CALL apoc.schema.nodes() YIELD label, properties, type, size, "
"valuesSelectivity WHERE type = 'RANGE' RETURN *, "
"size * valuesSelectivity as distinctValues"
)
except (
ClientError
): # Read-only user might not have access to schema information
constraint = []
index = []
self.structured_schema = {
"node_props": {el["labels"]: el["properties"] for el in node_properties},
"rel_props": {el["type"]: el["properties"] for el in rel_properties},
"relationships": relationships,
"metadata": {"constraint": constraint, "index": index},
}
if self._enhanced_schema:
# Element counts per label / relationship type decide between an
# exhaustive scan and sampling in _enhanced_schema_cypher.
schema_counts = self.query(
"CALL apoc.meta.graphSample() YIELD nodes, relationships "
"RETURN nodes, [rel in relationships | {name:apoc.any.property"
"(rel, 'type'), count: apoc.any.property(rel, 'count')}]"
" AS relationships"
)
# Update node info
for node in schema_counts[0]["nodes"]:
# Skip bloom labels
if node["name"] in EXCLUDED_LABELS:
continue
node_props = self.structured_schema["node_props"].get(node["name"])
if not node_props: # The node has no properties
continue
enhanced_cypher = self._enhanced_schema_cypher(
node["name"], node_props, node["count"] < EXHAUSTIVE_SEARCH_LIMIT
)
# Due to schema-flexible nature of neo4j errors can happen
try:
enhanced_info = self.query(enhanced_cypher)[0]["output"]
# Merge the extra keys into the property dicts already held in
# structured_schema (updated in place).
for prop in node_props:
if prop["property"] in enhanced_info:
prop.update(enhanced_info[prop["property"]])
except CypherTypeError:
continue
# Update rel info
for rel in schema_counts[0]["relationships"]:
# Skip bloom labels
if rel["name"] in EXCLUDED_RELS:
continue
rel_props = self.structured_schema["rel_props"].get(rel["name"])
if not rel_props: # The rel has no properties
continue
enhanced_cypher = self._enhanced_schema_cypher(
rel["name"],
rel_props,
rel["count"] < EXHAUSTIVE_SEARCH_LIMIT,
is_relationship=True,
)
try:
enhanced_info = self.query(enhanced_cypher)[0]["output"]
for prop in rel_props:
if prop["property"] in enhanced_info:
prop.update(enhanced_info[prop["property"]])
# Due to schema-flexible nature of neo4j errors can happen
except CypherTypeError:
continue
# Render the (possibly enriched) structured schema as text.
schema = _format_schema(self.structured_schema, self._enhanced_schema)
self.schema = schema
def add_graph_documents(
    self,
    graph_documents: List[GraphDocument],
    include_source: bool = False,
    baseEntityLabel: bool = False,
) -> None:
    """
    Construct nodes and relationships in the graph from the provided
    GraphDocument objects.

    Parameters:
    - graph_documents (List[GraphDocument]): A list of GraphDocument objects
    that contain the nodes and relationships to be added to the graph. Each
    GraphDocument should encapsulate the structure of part of the graph,
    including nodes, relationships, and the source document information.
    - include_source (bool, optional): If True, stores the source document
    and links it to nodes in the graph using the MENTIONS relationship.
    This is useful for tracing back the origin of data. Merges source
    documents based on the `id` property from the source document metadata
    if available; otherwise it calculates the MD5 hash of `page_content`
    for merging process. Defaults to False.
    - baseEntityLabel (bool, optional): If True, each newly created node
    gets a secondary __Entity__ label, which is indexed and improves import
    speed and performance. Defaults to False.
    """
    if baseEntityLabel:  # Check if constraint already exists
        known_constraints = self.structured_schema.get("metadata", {}).get(
            "constraint", []
        )
        constraint_exists = any(
            [
                el["labelsOrTypes"] == [BASE_ENTITY_LABEL]
                and el["properties"] == ["id"]
                for el in known_constraints
            ]
        )
        if not constraint_exists:
            # Create constraint
            self.query(
                f"CREATE CONSTRAINT IF NOT EXISTS FOR (b:{BASE_ENTITY_LABEL}) "
                "REQUIRE b.id IS UNIQUE;"
            )
            self.refresh_schema()  # Refresh constraint information

    node_import_query = _get_node_import_query(baseEntityLabel, include_source)
    rel_import_query = _get_rel_import_query(baseEntityLabel)

    for document in graph_documents:
        source = document.source
        # Source documents without an id are keyed by the MD5 of their content.
        if not source.metadata.get("id"):
            source.metadata["id"] = md5(
                source.page_content.encode("utf-8")
            ).hexdigest()

        # Remove backticks from node types
        for node in document.nodes:
            node.type = _remove_backticks(node.type)

        # Import nodes
        node_rows = [el.__dict__ for el in document.nodes]
        self.query(
            node_import_query,
            {"data": node_rows, "document": source.__dict__},
        )

        # Import relationships
        rel_rows = [
            {
                "source": el.source.id,
                "source_label": _remove_backticks(el.source.type),
                "target": el.target.id,
                "target_label": _remove_backticks(el.target.type),
                "type": _remove_backticks(el.type.replace(" ", "_").upper()),
                "properties": el.properties,
            }
            for el in document.relationships
        ]
        self.query(rel_import_query, {"data": rel_rows})
def _enhanced_schema_cypher(
self,
label_or_type: str,
properties: List[Dict[str, Any]],
exhaustive: bool,
is_relationship: bool = False,
) -> str:
# Build the Cypher query that collects example values / statistics for the
# given properties of one node label or relationship type.
#
# Args:
#   label_or_type: Node label, or relationship type when is_relationship.
#   properties: Property descriptors with "property" and "type" keys.
#   exhaustive: If True, aggregate over all matches; otherwise sample 5
#     elements, using range-index data for indexed properties.
#   is_relationship: Match relationships instead of nodes.
#
# Returns the query text. Its single `output` map is keyed by property name.
#
# NOTE(review): the non-exhaustive STRING branch may itself run a query
# (apoc.schema.properties.distinct) with label and property names
# interpolated into the query text - confirm these names are trusted.
# NOTE(review): if no property yields a WITH item, the generated query has
# an empty "WITH " clause - verify whether callers can hit this.
if is_relationship:
match_clause = f"MATCH ()-[n:`{label_or_type}`]->()"
else:
match_clause = f"MATCH (n:`{label_or_type}`)"
with_clauses = []
return_clauses = []
output_dict = {}
if exhaustive:
for prop in properties:
prop_name = prop["property"]
prop_type = prop["type"]
if prop_type == "STRING":
# Distinct values truncated to 50 characters; up to
# DISTINCT_VALUE_LIMIT of them are returned.
with_clauses.append(
(
f"collect(distinct substring(toString(n.`{prop_name}`)"
f", 0, 50)) AS `{prop_name}_values`"
)
)
return_clauses.append(
(
f"values:`{prop_name}_values`[..{DISTINCT_VALUE_LIMIT}],"
f" distinct_count: size(`{prop_name}_values`)"
)
)
elif prop_type in [
"INTEGER",
"FLOAT",
"DATE",
"DATE_TIME",
"LOCAL_DATE_TIME",
]:
with_clauses.append(f"min(n.`{prop_name}`) AS `{prop_name}_min`")
with_clauses.append(f"max(n.`{prop_name}`) AS `{prop_name}_max`")
with_clauses.append(
f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
)
return_clauses.append(
(
f"min: toString(`{prop_name}_min`), "
f"max: toString(`{prop_name}_max`), "
f"distinct_count: `{prop_name}_distinct`"
)
)
elif prop_type == "LIST":
with_clauses.append(
(
f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
)
)
return_clauses.append(
f"min_size: `{prop_name}_size_min`, "
f"max_size: `{prop_name}_size_max`"
)
elif prop_type in ["BOOLEAN", "POINT", "DURATION"]:
continue
# Consumes the entry appended for this property by the branch above.
output_dict[prop_name] = "{" + return_clauses.pop() + "}"
else:
# Just sample 5 random nodes
match_clause += " WITH n LIMIT 5"
for prop in properties:
prop_name = prop["property"]
prop_type = prop["type"]
# Check if indexed property, we can still do exhaustive
prop_index = [
el
for el in self.structured_schema["metadata"]["index"]
if el["label"] == label_or_type
and el["properties"] == [prop_name]
and el["type"] == "RANGE"
]
if prop_type == "STRING":
if (
prop_index
and prop_index[0].get("size") > 0
and prop_index[0].get("distinctValues") <= DISTINCT_VALUE_LIMIT
):
# Few distinct indexed values: fetch them now and embed them in the query.
distinct_values = self.query(
f"CALL apoc.schema.properties.distinct("
f"'{label_or_type}', '{prop_name}') YIELD value"
)[0]["value"]
return_clauses.append(
(
f"values: {distinct_values},"
f" distinct_count: {len(distinct_values)}"
)
)
else:
with_clauses.append(
(
f"collect(distinct substring(toString(n.`{prop_name}`)"
f", 0, 50)) AS `{prop_name}_values`"
)
)
return_clauses.append(f"values: `{prop_name}_values`")
elif prop_type in [
"INTEGER",
"FLOAT",
"DATE",
"DATE_TIME",
"LOCAL_DATE_TIME",
]:
if not prop_index:
# No range index: collect the sampled values as strings.
with_clauses.append(
f"collect(distinct toString(n.`{prop_name}`)) "
f"AS `{prop_name}_values`"
)
return_clauses.append(f"values: `{prop_name}_values`")
else:
with_clauses.append(
f"min(n.`{prop_name}`) AS `{prop_name}_min`"
)
with_clauses.append(
f"max(n.`{prop_name}`) AS `{prop_name}_max`"
)
with_clauses.append(
f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
)
return_clauses.append(
(
f"min: toString(`{prop_name}_min`), "
f"max: toString(`{prop_name}_max`), "
f"distinct_count: `{prop_name}_distinct`"
)
)
elif prop_type == "LIST":
with_clauses.append(
(
f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
)
)
return_clauses.append(
(
f"min_size: `{prop_name}_size_min`, "
f"max_size: `{prop_name}_size_max`"
)
)
elif prop_type in ["BOOLEAN", "POINT", "DURATION"]:
continue
# Consumes the entry appended for this property by the branch above.
output_dict[prop_name] = "{" + return_clauses.pop() + "}"
with_clause = "WITH " + ",\n ".join(with_clauses)
return_clause = (
"RETURN {"
+ ", ".join(f"`{k}`: {v}" for k, v in output_dict.items())
+ "} AS output"
)
# Combine all parts of the Cypher query
cypher_query = "\n".join([match_clause, with_clause, return_clause])
return cypher_query

View File

@@ -0,0 +1,426 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from langchain_core._api.deprecation import deprecated
class NeptuneQueryException(Exception):
    """Error raised when a Neptune query or schema operation fails.

    The constructor accepts either a plain message or a dict that may carry
    ``"message"`` and ``"details"`` entries; whatever is missing is stored
    as ``"unknown"``.
    """

    def __init__(self, exception: Union[str, Dict]):
        if not isinstance(exception, dict):
            # Plain payload: it is the message, and no details are available.
            self.message = exception
            self.details = "unknown"
            return
        self.message = exception.get("message", "unknown")
        self.details = exception.get("details", "unknown")

    def get_message(self) -> str:
        """Return the stored message."""
        return self.message

    def get_details(self) -> Any:
        """Return the stored details (``"unknown"`` when none were given)."""
        return self.details
class BaseNeptuneGraph(ABC):
    """Abstract base class for Neptune.

    Subclasses supply ``query`` and ``_get_summary``; this class uses them to
    build the textual schema stored on ``self.schema``.
    """

    @property
    def get_schema(self) -> str:
        """Return the schema of the Neptune database"""
        return self.schema

    @abstractmethod
    def query(self, query: str, params: dict = {}) -> dict:
        """Run ``query`` against the database and return its results.

        NOTE: the default ``params`` dict is shared between calls, so
        implementations must not mutate it.
        """
        raise NotImplementedError()

    @abstractmethod
    def _get_summary(self) -> Dict:
        """Return the graph summary containing ``nodeLabels``/``edgeLabels``."""
        raise NotImplementedError()

    def _get_labels(self) -> Tuple[List[str], List[str]]:
        """Get node and edge labels from the Neptune statistics summary"""
        summary = self._get_summary()
        n_labels = summary["nodeLabels"]
        e_labels = summary["edgeLabels"]
        return n_labels, e_labels

    def _get_triples(self, e_labels: List[str]) -> List[str]:
        """Describe, per edge label, which node labels it connects.

        For every label a sample of matching edges is queried and each
        distinct row is rendered as ``(:`from`)-[:`edge`]->(:`to`)`` using
        the first label of the start and end node.
        """
        triple_query = (
            "\n"
            "MATCH (a)-[e:`{e_label}`]->(b)\n"
            "WITH a,e,b LIMIT 3000\n"
            "RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to\n"
            "LIMIT 10\n"
        )
        triple_template = "(:`{a}`)-[:`{e}`]->(:`{b}`)"
        triple_schema = []
        for label in e_labels:
            for row in self.query(triple_query.format(e_label=label)):
                triple_schema.append(
                    triple_template.format(
                        a=row["from"][0], e=row["edge"], b=row["to"][0]
                    )
                )
        return triple_schema

    def _sample_property_types(
        self, query: str, types: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """List the distinct (property, type) pairs found by a ``props`` query.

        ``query`` must return rows with a ``props`` mapping. Pairs are kept
        in first-seen order so the generated schema text is stable between
        runs (iterating a ``set`` of strings is not).

        Raises:
            KeyError: if a value's Python type name is missing from ``types``.
        """
        seen: Dict[Tuple[str, Any], None] = {}
        for row in self.query(query):
            for name, value in row["props"].items():
                seen[(name, types[type(value).__name__])] = None
        return [{"property": name, "type": type_name} for name, type_name in seen]

    def _get_node_properties(self, n_labels: List[str], types: Dict) -> List:
        """Return, per node label, the property names/types seen in a sample."""
        node_properties_query = (
            "\n"
            "MATCH (a:`{n_label}`)\n"
            "RETURN properties(a) AS props\n"
            "LIMIT 100\n"
        )
        node_properties = []
        for label in n_labels:
            q = node_properties_query.format(n_label=label)
            node_properties.append(
                {
                    "properties": self._sample_property_types(q, types),
                    "labels": label,
                }
            )
        return node_properties

    def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List:
        """Return, per edge label, the property names/types seen in a sample."""
        edge_properties_query = (
            "\n"
            "MATCH ()-[e:`{e_label}`]->()\n"
            "RETURN properties(e) AS props\n"
            "LIMIT 100\n"
        )
        edge_properties = []
        for label in e_labels:
            q = edge_properties_query.format(e_label=label)
            edge_properties.append(
                {
                    "type": label,
                    "properties": self._sample_property_types(q, types),
                }
            )
        return edge_properties

    def _refresh_schema(self) -> None:
        """
        Refreshes the Neptune graph schema information.
        """
        # Maps Python type names of returned values to schema type names.
        types = {
            "str": "STRING",
            "float": "DOUBLE",
            "int": "INTEGER",
            "list": "LIST",
            "dict": "MAP",
            "bool": "BOOLEAN",
        }
        n_labels, e_labels = self._get_labels()
        triple_schema = self._get_triples(e_labels)
        node_properties = self._get_node_properties(n_labels, types)
        edge_properties = self._get_edge_properties(e_labels, types)
        self.schema = (
            "\n"
            "Node properties are the following:\n"
            f"{node_properties}\n"
            "Relationship properties are the following:\n"
            f"{edge_properties}\n"
            "The relationships are the following:\n"
            f"{triple_schema}\n"
        )
@deprecated(
    since="0.3.15",
    removal="1.0",
    alternative_import="langchain_aws.NeptuneAnalyticsGraph",
)
class NeptuneAnalyticsGraph(BaseNeptuneGraph):
    """Neptune Analytics wrapper for graph operations.

    Parameters:
        client: optional boto3 Neptune client
        credentials_profile_name: optional AWS profile name
        region_name: optional AWS region, e.g., us-west-2
        graph_identifier: the graph identifier for a Neptune Analytics graph

    Example:
        .. code-block:: python

            graph = NeptuneAnalyticsGraph(
                graph_identifier='<my-graph-id>'
            )

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        graph_identifier: str,
        client: Any = None,
        credentials_profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
    ) -> None:
        """Create a new Neptune Analytics graph wrapper instance.

        Raises:
            ImportError: if boto3 is missing or too old to know the
                ``neptune-graph`` service.
            ValueError: if the boto3 client cannot be created for any
                other reason.
            NeptuneQueryException: if the initial schema refresh fails.
        """
        # query() and _get_summary() read this regardless of how the client
        # was obtained, so store it before branching on ``client``.
        self.graph_identifier = graph_identifier

        try:
            if client is not None:
                self.client = client
            else:
                import boto3

                if credentials_profile_name is not None:
                    session = boto3.Session(profile_name=credentials_profile_name)
                else:
                    # use default credentials
                    session = boto3.Session()

                if region_name:
                    self.client = session.client(
                        "neptune-graph", region_name=region_name
                    )
                else:
                    self.client = session.client("neptune-graph")

        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e
        except Exception as e:
            # Matched by name so botocore does not have to be imported here.
            if type(e).__name__ == "UnknownServiceError":
                raise ImportError(
                    "NeptuneGraph requires a boto3 version 1.34.40 or greater. "
                    "Please install it with `pip install -U boto3`."
                ) from e
            else:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e

        try:
            self._refresh_schema()
        except Exception as e:
            # NeptuneQueryException reads the "details" key.
            raise NeptuneQueryException(
                {
                    "message": "Could not get schema for Neptune database",
                    "details": str(e),
                }
            ) from e

    def query(self, query: str, params: dict = {}) -> Dict[str, Any]:
        """Query Neptune database.

        Sends ``query`` as openCypher with ``params`` as its parameters and
        returns the ``results`` entry of the decoded JSON payload.

        Raises:
            NeptuneQueryException: on any failure while executing the query
                or decoding the response.
        """
        try:
            resp = self.client.execute_query(
                graphIdentifier=self.graph_identifier,
                queryString=query,
                parameters=params,
                language="OPEN_CYPHER",
            )
            return json.loads(resp["payload"].read().decode("UTF-8"))["results"]
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "An error occurred while executing the query.",
                    "details": str(e),
                }
            ) from e

    def _get_summary(self) -> Dict:
        """Fetch the detailed graph summary and return its ``graphSummary``.

        Raises:
            NeptuneQueryException: if the summary call fails or the response
                has no ``graphSummary`` entry.
        """
        try:
            response = self.client.get_graph_summary(
                graphIdentifier=self.graph_identifier, mode="detailed"
            )
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": ("Summary API error occurred on Neptune Analytics"),
                    "details": str(e),
                }
            ) from e

        try:
            summary = response["graphSummary"]
        except Exception:
            # The response may be a plain mapping without ``.content``; fall
            # back to its string form so this exception is the one raised.
            content = getattr(response, "content", None)
            if isinstance(content, (bytes, bytearray)):
                details = content.decode()
            else:
                details = str(response)
            raise NeptuneQueryException(
                {
                    "message": "Summary API did not return a valid response.",
                    "details": details,
                }
            )
        else:
            return summary
@deprecated(
    since="0.3.15",
    removal="1.0",
    alternative_import="langchain_aws.NeptuneGraph",
)
class NeptuneGraph(BaseNeptuneGraph):
    """Neptune wrapper for graph operations.

    Parameters:
        host: endpoint for the database instance
        port: port number for the database instance, default is 8182
        use_https: whether to use secure connection, default is True
        client: optional boto3 Neptune client
        credentials_profile_name: optional AWS profile name
        region_name: optional AWS region, e.g., us-west-2
        sign: optional, whether to sign the request payload, default is True

    Example:
        .. code-block:: python

            graph = NeptuneGraph(
                host='<my-cluster>',
                port=8182
            )

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        host: str,
        port: int = 8182,
        use_https: bool = True,
        client: Any = None,
        credentials_profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
        sign: bool = True,
    ) -> None:
        """Create a new Neptune graph wrapper instance.

        Raises:
            ImportError: if boto3 is missing or too old to know the
                ``neptunedata`` service.
            ValueError: if the boto3 client cannot be created for any
                other reason.
            NeptuneQueryException: if the initial schema refresh fails.
        """
        try:
            if client is not None:
                self.client = client
            else:
                import boto3

                if credentials_profile_name is not None:
                    session = boto3.Session(profile_name=credentials_profile_name)
                else:
                    # use default credentials
                    session = boto3.Session()

                client_params = {}
                if region_name:
                    client_params["region_name"] = region_name

                protocol = "https" if use_https else "http"
                client_params["endpoint_url"] = f"{protocol}://{host}:{port}"

                if sign:
                    self.client = session.client("neptunedata", **client_params)
                else:
                    from botocore import UNSIGNED
                    from botocore.config import Config

                    self.client = session.client(
                        "neptunedata",
                        **client_params,
                        config=Config(signature_version=UNSIGNED),
                    )

        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e
        except Exception as e:
            # Matched by name so botocore does not have to be imported here.
            if type(e).__name__ == "UnknownServiceError":
                raise ImportError(
                    "NeptuneGraph requires a boto3 version 1.28.38 or greater. "
                    "Please install it with `pip install -U boto3`."
                ) from e
            else:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e

        try:
            self._refresh_schema()
        except Exception as e:
            # NeptuneQueryException reads the "details" key.
            raise NeptuneQueryException(
                {
                    "message": "Could not get schema for Neptune database",
                    "details": str(e),
                }
            ) from e

    def query(self, query: str, params: dict = {}) -> Dict[str, Any]:
        """Query Neptune database.

        Runs ``query`` as openCypher and returns the ``results`` entry of
        the response. Non-empty ``params`` are JSON-encoded and sent as the
        request's ``parameters``; with empty ``params`` the request carries
        only the query.

        Raises:
            NeptuneQueryException: on any failure while executing the query.
        """
        try:
            request: Dict[str, Any] = {"openCypherQuery": query}
            if params:
                request["parameters"] = json.dumps(params)
            return self.client.execute_open_cypher_query(**request)["results"]
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "An error occurred while executing the query.",
                    "details": str(e),
                }
            ) from e

    def _get_summary(self) -> Dict:
        """Fetch the property-graph summary and return its ``graphSummary``.

        Raises:
            NeptuneQueryException: if the summary call fails or the response
                has no ``payload``/``graphSummary`` entry.
        """
        try:
            response = self.client.get_propertygraph_summary()
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": (
                        "Summary API is not available for this instance of Neptune, "
                        "ensure the engine version is >=1.2.1.0"
                    ),
                    "details": str(e),
                }
            ) from e

        try:
            summary = response["payload"]["graphSummary"]
        except Exception:
            # The response may be a plain mapping without ``.content``; fall
            # back to its string form so this exception is the one raised.
            content = getattr(response, "content", None)
            if isinstance(content, (bytes, bytearray)):
                details = content.decode()
            else:
                details = str(response)
            raise NeptuneQueryException(
                {
                    "message": "Summary API did not return a valid response.",
                    "details": details,
                }
            )
        else:
            return summary

View File

@@ -0,0 +1,302 @@
import json
from types import SimpleNamespace
from typing import Any, Dict, Optional, Sequence
import requests
from langchain_core._api.deprecation import deprecated
# SPARQL query selecting every element typed as an OWL datatype property.
DTPROP_QUERY = (
    "\n"
    "SELECT DISTINCT ?elem\n"
    "WHERE {\n"
    "?elem a owl:DatatypeProperty .\n"
    "}\n"
)

# SPARQL query selecting every element typed as an OWL object property.
OPROP_QUERY = (
    "\n"
    "SELECT DISTINCT ?elem\n"
    "WHERE {\n"
    "?elem a owl:ObjectProperty .\n"
    "}\n"
)

# Schema element kinds mapped to the SPARQL query that lists them; ``None``
# marks kinds that are not obtained through one of the queries above.
ELEM_TYPES = dict(
    classes=None,
    rels=None,
    dtprops=DTPROP_QUERY,
    oprops=OPROP_QUERY,
)
@deprecated(
    since="0.3.15",
    removal="1.0",
    alternative_import="langchain_aws.NeptuneRdfGraph",
)
class NeptuneRdfGraph:
    """Neptune wrapper for RDF graph operations.

    Args:
        host: endpoint for the database instance
        port: port number for the database instance, default is 8182
        use_iam_auth: boolean indicating IAM auth is enabled in Neptune cluster
        use_https: whether to use secure connection, default is True
        client: optional boto3 Neptune client
        credentials_profile_name: optional AWS profile name
        region_name: optional AWS region, e.g., us-west-2
        service: optional service name, default is neptunedata
        sign: optional, whether to sign the request payload, default is True

    Example:
        .. code-block:: python

        graph = NeptuneRdfGraph(
            host='<SPARQL host>',
            port=<SPARQL port>
        )
        schema = graph.get_schema

        OR
        graph = NeptuneRdfGraph(
            host='<SPARQL host>',
            port=<SPARQL port>
        )
        schema_elem = graph.get_schema_elements
        #... change schema_elements ...
        graph.load_schema(schema_elem)

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        host: str,
        port: int = 8182,
        use_https: bool = True,
        use_iam_auth: bool = False,
        client: Any = None,
        credentials_profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
        service: str = "neptunedata",
        sign: bool = True,
    ) -> None:
        """Create the wrapper, build the client and introspect the schema.

        Raises:
            ImportError: if boto3 is missing or too old to know ``service``.
            ValueError: if the boto3 client cannot be created for any
                other reason.
        """
        self.use_iam_auth = use_iam_auth
        self.region_name = region_name
        # The SPARQL endpoint follows the same scheme as the boto3 endpoint.
        protocol = "https" if use_https else "http"
        self.query_endpoint = f"{protocol}://{host}:{port}/sparql"
        # Stays None when a ready-made client is supplied; query() creates a
        # default session on demand if IAM signing needs credentials.
        self.session: Any = None

        try:
            if client is not None:
                self.client = client
            else:
                import boto3

                if credentials_profile_name is not None:
                    self.session = boto3.Session(profile_name=credentials_profile_name)
                else:
                    # use default credentials
                    self.session = boto3.Session()

                client_params = {}
                if region_name:
                    client_params["region_name"] = region_name

                client_params["endpoint_url"] = f"{protocol}://{host}:{port}"

                if sign:
                    self.client = self.session.client(service, **client_params)
                else:
                    from botocore import UNSIGNED
                    from botocore.config import Config

                    self.client = self.session.client(
                        service,
                        **client_params,
                        config=Config(signature_version=UNSIGNED),
                    )

        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e
        except Exception as e:
            # Matched by name so botocore does not have to be imported here.
            if type(e).__name__ == "UnknownServiceError":
                raise ImportError(
                    "NeptuneGraph requires a boto3 version 1.28.38 or greater. "
                    "Please install it with `pip install -U boto3`."
                ) from e
            else:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e

        # Set schema
        self.schema = ""
        self.schema_elements: Dict[str, Any] = {}

        self._refresh_schema()

    @property
    def get_schema(self) -> str:
        """
        Returns the schema of the graph database.
        """
        return self.schema

    @property
    def get_schema_elements(self) -> Dict[str, Any]:
        """Returns the schema elements the schema text is generated from."""
        return self.schema_elements

    def get_summary(self) -> Dict[str, Any]:
        """
        Obtain Neptune statistical summary of classes and predicates in the graph.
        """
        return self.client.get_rdf_graph_summary(mode="detailed")

    def query(
        self,
        query: str,
    ) -> Dict[str, Any]:
        """
        Run Neptune query.

        POSTs ``query`` as form data to the SPARQL endpoint and returns the
        JSON-decoded response body. With ``use_iam_auth`` the request is
        SigV4-signed using the session's credentials.
        """
        data = {"query": query}
        if self.use_iam_auth:
            if self.session is None:
                # A client was supplied to __init__, so no session exists
                # yet; fall back to default credentials for signing.
                import boto3

                self.session = boto3.Session()
            credentials = self.session.get_credentials().get_frozen_credentials()
            service = "neptune-db"
            creds = SimpleNamespace(
                access_key=credentials.access_key,
                secret_key=credentials.secret_key,
                token=credentials.token,
                region=self.region_name,
            )
            from botocore.awsrequest import AWSRequest

            request = AWSRequest(
                method="POST", url=self.query_endpoint, data=data, params=None
            )
            from botocore.auth import SigV4Auth

            SigV4Auth(creds, service, self.region_name).add_auth(request)
            request.headers["Content-Type"] = "application/x-www-form-urlencoded"
            request_hdr = request.headers
        else:
            request_hdr = {"Content-Type": "application/x-www-form-urlencoded"}

        queryres = requests.request(
            method="POST", url=self.query_endpoint, headers=request_hdr, data=data
        )
        return json.loads(queryres.text)

    def load_schema(self, schema_elements: Dict[str, Any]) -> None:
        """
        Generates and sets schema from schema_elements. Helpful in
        cases where introspected schema needs pruning.
        """
        elem_str = {}
        for elem in ELEM_TYPES:
            elem_str[elem] = ", ".join(
                f"<{rec['uri']}> ({rec['local']})" for rec in schema_elements[elem]
            )

        # "oprops" holds owl:ObjectProperty results and "dtprops" holds
        # owl:DatatypeProperty results; each goes under its own heading.
        self.schema = (
            "In the following, each IRI is followed by the local name and "
            "optionally its description in parentheses. \n"
            "The graph supports the following node types:\n"
            f"{elem_str['classes']}\n"
            "The graph supports the following relationships:\n"
            f"{elem_str['rels']}\n"
            "The graph supports the following OWL object properties:\n"
            f"{elem_str['oprops']}\n"
            "The graph supports the following OWL data properties:\n"
            f"{elem_str['dtprops']}"
        )

    def _get_local_name(self, iri: str) -> Sequence[str]:
        """
        Split IRI into prefix and local

        Raises:
            ValueError: if ``iri`` contains neither ``#`` nor ``/``.
        """
        if "#" in iri:
            tokens = iri.split("#")
            return [f"{tokens[0]}#", tokens[-1]]
        elif "/" in iri:
            tokens = iri.split("/")
            return [f"{'/'.join(tokens[0 : len(tokens) - 1])}/", tokens[-1]]
        else:
            raise ValueError(f"Unexpected IRI '{iri}', contains neither '#' nor '/'.")

    def _make_elem_record(self, uri: str) -> Dict[str, str]:
        """Build the ``{"uri", "local"}`` record for ``uri`` and note its prefix."""
        prefix, local = self._get_local_name(uri)
        known_prefixes = self.schema_elements["distinct_prefixes"]
        if prefix not in known_prefixes:
            known_prefixes[prefix] = "y"
        return {"uri": uri, "local": local}

    def _refresh_schema(self) -> None:
        """
        Query Neptune to introspect schema.
        """
        self.schema_elements["distinct_prefixes"] = {}

        # get summary and build list of classes and rels
        graph_summary = self.get_summary()["payload"]["graphSummary"]
        self.schema_elements["classes"] = [
            self._make_elem_record(uri) for uri in graph_summary["classes"]
        ]
        # Each predicates entry is iterated for the URIs it contains.
        self.schema_elements["rels"] = [
            self._make_elem_record(uri)
            for entry in graph_summary["predicates"]
            for uri in entry
        ]

        # get dtprops and oprops too
        for elem, q in ELEM_TYPES.items():
            if not q:
                continue
            items = self.query(q)
            self.schema_elements[elem] = [
                self._make_elem_record(binding["elem"]["value"])
                for binding in items["results"]["bindings"]
            ]

        self.load_schema(self.schema_elements)

View File

@@ -0,0 +1,218 @@
"""Networkx wrapper for graph operations."""
from __future__ import annotations
from typing import Any, List, NamedTuple, Optional, Tuple
# Separator between consecutive triples in a knowledge string; parse_triples
# splits its input on this marker.
KG_TRIPLE_DELIMITER = "<|>"
class KnowledgeTriple(NamedTuple):
    """Knowledge triple in the graph."""

    subject: str
    predicate: str
    object_: str

    @classmethod
    def from_string(cls, triple_string: str) -> "KnowledgeTriple":
        """Create a KnowledgeTriple from a string.

        The stripped text is split on ``", "`` into exactly three parts; the
        first character of the first part and the last character of the
        last part are dropped (the enclosing brackets of ``(s, p, o)``).

        Raises:
            ValueError: if the split does not yield exactly three parts.
        """
        head, relation, tail = triple_string.strip().split(", ")
        return cls(head[1:], relation, tail[:-1])
def parse_triples(knowledge_str: str) -> List[KnowledgeTriple]:
    """Parse knowledge triples from the knowledge string.

    Returns an empty list for blank input or the literal ``NONE``. Chunks
    that ``KnowledgeTriple.from_string`` rejects with ``ValueError`` are
    skipped.
    """
    text = knowledge_str.strip()
    if text in ("", "NONE"):
        return []

    parsed = []
    for chunk in text.split(KG_TRIPLE_DELIMITER):
        try:
            parsed.append(KnowledgeTriple.from_string(chunk))
        except ValueError:
            continue
    return parsed
def get_entities(entity_str: str) -> List[str]:
    """Extract entities from entity string.

    The literal ``NONE`` (ignoring surrounding whitespace) yields an empty
    list; otherwise the string is split on commas and each piece stripped.
    """
    if entity_str.strip() == "NONE":
        return []
    return list(map(str.strip, entity_str.split(",")))
class NetworkxEntityGraph:
    """Networkx wrapper for entity graph operations.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, graph: Optional[Any] = None) -> None:
        """Create a new graph.

        Args:
            graph: an existing ``networkx.DiGraph`` to wrap; a new empty
                one is created when omitted.

        Raises:
            ImportError: if networkx is not installed.
            ValueError: if ``graph`` is given but is not a ``DiGraph``.
        """
        try:
            import networkx as nx
        except ImportError as e:
            raise ImportError(
                "Could not import networkx python package. "
                "Please install it with `pip install networkx`."
            ) from e
        if graph is not None:
            if not isinstance(graph, nx.DiGraph):
                raise ValueError("Passed in graph is not of correct shape")
            self._graph = graph
        else:
            self._graph = nx.DiGraph()

    @classmethod
    def from_gml(cls, gml_path: str) -> "NetworkxEntityGraph":
        """Build an instance from the GML file at ``gml_path``."""
        try:
            import networkx as nx
        except ImportError as e:
            raise ImportError(
                "Could not import networkx python package. "
                "Please install it with `pip install networkx`."
            ) from e
        graph = nx.read_gml(gml_path)
        return cls(graph)

    def add_triple(self, knowledge_triple: "KnowledgeTriple") -> None:
        """Add a triple to the graph."""
        # Creates nodes if they don't exist
        # Overwrites existing edges
        if not self._graph.has_node(knowledge_triple.subject):
            self._graph.add_node(knowledge_triple.subject)
        if not self._graph.has_node(knowledge_triple.object_):
            self._graph.add_node(knowledge_triple.object_)
        self._graph.add_edge(
            knowledge_triple.subject,
            knowledge_triple.object_,
            relation=knowledge_triple.predicate,
        )

    def delete_triple(self, knowledge_triple: "KnowledgeTriple") -> None:
        """Delete a triple from the graph."""
        if self._graph.has_edge(knowledge_triple.subject, knowledge_triple.object_):
            self._graph.remove_edge(knowledge_triple.subject, knowledge_triple.object_)

    def get_triples(self) -> List[Tuple[str, str, str]]:
        """Get all triples in the graph.

        Each tuple is ``(source, target, relation)``, i.e. the relation
        comes last.
        """
        return [(u, v, d["relation"]) for u, v, d in self._graph.edges(data=True)]

    def get_entity_knowledge(self, entity: str, depth: int = 1) -> List[str]:
        """Get information about an entity.

        Returns ``"<src> <relation> <sink>"`` strings for the edges of a
        depth-limited DFS starting at ``entity``; empty if it is not a node.
        """
        import networkx as nx

        # TODO: Have more information-specific retrieval methods
        if not self._graph.has_node(entity):
            return []

        results = []
        for src, sink in nx.dfs_edges(self._graph, entity, depth_limit=depth):
            relation = self._graph[src][sink]["relation"]
            results.append(f"{src} {relation} {sink}")
        return results

    def write_to_gml(self, path: str) -> None:
        """Write the graph to ``path`` in GML format."""
        import networkx as nx

        nx.write_gml(self._graph, path)

    def clear(self) -> None:
        """Clear the graph."""
        self._graph.clear()

    def clear_edges(self) -> None:
        """Clear the graph edges."""
        self._graph.clear_edges()

    def add_node(self, node: str) -> None:
        """Add node in the graph."""
        self._graph.add_node(node)

    def remove_node(self, node: str) -> None:
        """Remove node from the graph."""
        if self._graph.has_node(node):
            self._graph.remove_node(node)

    def has_node(self, node: str) -> bool:
        """Return if graph has the given node."""
        return self._graph.has_node(node)

    def remove_edge(self, source_node: str, destination_node: str) -> None:
        """Remove edge from the graph."""
        self._graph.remove_edge(source_node, destination_node)

    def has_edge(self, source_node: str, destination_node: str) -> bool:
        """Return if graph has an edge between the given nodes."""
        if self._graph.has_node(source_node) and self._graph.has_node(destination_node):
            return self._graph.has_edge(source_node, destination_node)
        else:
            return False

    def get_neighbors(self, node: str) -> List[str]:
        """Return the neighbor nodes of the given node."""
        # Materialized so the result matches the declared list return type.
        return list(self._graph.neighbors(node))

    def get_number_of_nodes(self) -> int:
        """Get number of nodes in the graph."""
        return self._graph.number_of_nodes()

    def get_topological_sort(self) -> List[str]:
        """Get a list of entity names in the graph sorted by causal dependence."""
        import networkx as nx

        return list(nx.topological_sort(self._graph))

    def draw_graphviz(self, **kwargs: Any) -> None:
        """
        Provides better drawing

        Keyword Args:
            prog: layout program passed to ``layout`` (default ``"dot"``).
            path: output file passed to ``draw`` (default ``"graph.svg"``).

        Usage in a jupyter notebook:

            >>> from IPython.display import SVG
            >>> graph.draw_graphviz(prog="dot", path="web.svg")
            >>> SVG('web.svg')
        """
        from networkx.drawing.nx_agraph import to_agraph

        try:
            import pygraphviz  # noqa: F401

        except ImportError as e:
            if e.name == "_graphviz":
                # pygraphviz throws this error, e.g.
                #   ImportError: libcgraph.so.6: cannot open shared object file
                raise ImportError(
                    "Could not import graphviz debian package. "
                    "Please install it with:"
                    "`sudo apt-get update`"
                    "`sudo apt-get install graphviz graphviz-dev`"
                ) from e
            else:
                raise ImportError(
                    "Could not import pygraphviz python package. "
                    "Please install it with:"
                    "`pip install pygraphviz`."
                ) from e

        graph = to_agraph(self._graph)  # --> pygraphviz.agraph.AGraph
        # pygraphviz.github.io/documentation/stable/tutorial.html#layout-and-drawing
        graph.layout(prog=kwargs.get("prog", "dot"))
        graph.draw(kwargs.get("path", "graph.svg"))

View File

@@ -0,0 +1,216 @@
from __future__ import annotations
import os
from typing import (
TYPE_CHECKING,
List,
Optional,
Union,
)
if TYPE_CHECKING:
import rdflib
class OntotextGraphDBGraph:
    """Ontotext GraphDB https://graphdb.ontotext.com/ wrapper for graph operations.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        query_endpoint: str,
        query_ontology: Optional[str] = None,
        local_file: Optional[str] = None,
        local_file_format: Optional[str] = None,
    ) -> None:
        """
        Set up the GraphDB wrapper

        :param query_endpoint: SPARQL endpoint for queries, read access

        If GraphDB is secured,
        set the environment variables 'GRAPHDB_USERNAME' and 'GRAPHDB_PASSWORD'.

        :param query_ontology: a `CONSTRUCT` query that is executed
        on the SPARQL endpoint and returns the KG schema statements
        Example:
        'CONSTRUCT {?s ?p ?o} FROM <https://example.com/ontology/> WHERE {?s ?p ?o}'
        Currently, DESCRIBE queries like
        'PREFIX onto: <https://example.com/ontology/>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        DESCRIBE ?term WHERE {
            ?term rdfs:isDefinedBy onto:
        }'
        are not supported, because DESCRIBE returns
        the Symmetric Concise Bounded Description (SCBD),
        i.e. also the incoming class links.
        In case of large graphs with a million of instances, this is not efficient.
        Check https://github.com/eclipse-rdf4j/rdf4j/issues/4857

        :param local_file: a local RDF ontology file.
        Supported RDF formats:
        Turtle, RDF/XML, JSON-LD, N-Triples, Notation-3, Trig, Trix, N-Quads.
        If the rdf format can't be determined from the file extension,
        pass explicitly the rdf format in `local_file_format` param.

        :param local_file_format: Used if the rdf format can't be determined
        from the local file extension.
        One of "json-ld", "xml", "n3", "turtle", "nt", "trig", "nquads", "trix"

        Either `query_ontology` or `local_file` should be passed.

        :raises ValueError: if both or neither of `query_ontology` and
        `local_file` are given.
        :raises ImportError: if rdflib is not installed.
        """

        if query_ontology and local_file:
            raise ValueError("Both file and query provided. Only one is allowed.")

        if not query_ontology and not local_file:
            raise ValueError("Neither file nor query provided. One is required.")

        try:
            import rdflib
            from rdflib.plugins.stores import sparqlstore
        except ImportError as e:
            raise ImportError(
                "Could not import rdflib python package. "
                "Please install it with `pip install rdflib`."
            ) from e

        auth = self._get_auth()
        store = sparqlstore.SPARQLStore(auth=auth)
        store.open(query_endpoint)

        self.graph = rdflib.Graph(store, identifier=None, bind_namespaces="none")
        self._check_connectivity()

        ontology_schema_graph: "rdflib.Graph"
        if local_file:
            ontology_schema_graph = self._load_ontology_schema_from_file(
                local_file,
                local_file_format,
            )
        else:
            self._validate_user_query(query_ontology)  # type: ignore[arg-type]
            ontology_schema_graph = self._load_ontology_schema_with_query(
                query_ontology  # type: ignore[arg-type]
            )
        self.schema = ontology_schema_graph.serialize(format="turtle")

    @staticmethod
    def _get_auth() -> Union[tuple, None]:
        """
        Returns the basic authentication configuration

        Reads 'GRAPHDB_USERNAME' and 'GRAPHDB_PASSWORD' from the environment.
        Returns ``(username, password)``, or ``None`` when no username is set.

        :raises ValueError: if a username is set without a password.
        """
        username = os.environ.get("GRAPHDB_USERNAME")
        password = os.environ.get("GRAPHDB_PASSWORD")

        if not username:
            return None
        if not password:
            raise ValueError(
                "Environment variable 'GRAPHDB_USERNAME' is set, "
                "but 'GRAPHDB_PASSWORD' is not set."
            )
        return username, password

    def _check_connectivity(self) -> None:
        """
        Executes a simple `ASK` query to check connectivity

        :raises ValueError: if the query raises ``ValueError``.
        """
        try:
            self.graph.query("ASK { ?s ?p ?o }")
        except ValueError as e:
            raise ValueError(
                "Could not query the provided endpoint. "
                "Please, check, if the value of the provided "
                "query_endpoint points to the right repository. "
                "If GraphDB is secured, please, "
                "make sure that the environment variables "
                "'GRAPHDB_USERNAME' and 'GRAPHDB_PASSWORD' are set."
            ) from e

    @staticmethod
    def _load_ontology_schema_from_file(
        local_file: str, local_file_format: Optional[str] = None
    ) -> "rdflib.ConjunctiveGraph":
        """
        Parse the ontology schema statements from the provided file

        :raises FileNotFoundError: if `local_file` does not exist.
        :raises PermissionError: if `local_file` is not readable.
        :raises ValueError: if parsing the file fails.
        """
        import rdflib

        if not os.path.exists(local_file):
            raise FileNotFoundError(f"File {local_file} does not exist.")
        if not os.access(local_file, os.R_OK):
            raise PermissionError(f"Read permission for {local_file} is restricted")
        graph = rdflib.ConjunctiveGraph()
        try:
            graph.parse(local_file, format=local_file_format)
        except Exception as e:
            # Put the cause into the message itself; passing it as a second
            # argument renders the error as a tuple.
            raise ValueError(f"Invalid file format for {local_file} : {e}") from e
        return graph

    @staticmethod
    def _validate_user_query(query_ontology: str) -> None:
        """
        Validate the query is a valid SPARQL CONSTRUCT query

        :raises TypeError: if `query_ontology` is not a string.
        :raises ValueError: if the query does not parse or is not a
        CONSTRUCT query.
        """
        from pyparsing import ParseException
        from rdflib.plugins.sparql import prepareQuery

        if not isinstance(query_ontology, str):
            raise TypeError("Ontology query must be provided as string.")
        try:
            parsed_query = prepareQuery(query_ontology)
        except ParseException as e:
            raise ValueError(
                f"Ontology query is not a valid SPARQL query. {e}"
            ) from e

        if parsed_query.algebra.name != "ConstructQuery":
            raise ValueError(
                "Invalid query type. Only CONSTRUCT queries are supported."
            )

    def _load_ontology_schema_with_query(self, query: str) -> "rdflib.Graph":
        """
        Execute the query for collecting the ontology schema statements

        :raises ValueError: if the query raises a parser error or the
        result carries no graph.
        """
        from rdflib.exceptions import ParserError

        try:
            results = self.graph.query(query)
        except ParserError as e:
            raise ValueError(f"Generated SPARQL statement is invalid\n{e}") from e

        if not results.graph:
            raise ValueError("Missing graph in results.")

        return results.graph

    @property
    def get_schema(self) -> str:
        """
        Returns the schema of the graph database in turtle format
        """
        return self.schema

    def query(
        self,
        query: str,
    ) -> List["rdflib.query.ResultRow"]:
        """
        Query the graph.

        Returns only the results that are ``ResultRow`` instances.
        """
        from rdflib.query import ResultRow

        res = self.graph.query(query)
        return [r for r in res if isinstance(r, ResultRow)]

View File

@@ -0,0 +1,307 @@
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Dict,
List,
Optional,
)
if TYPE_CHECKING:
import rdflib
# SPARQL PREFIX declarations, keyed by the namespace abbreviation they bind.
# Every value ends with a newline so it can be prepended to a query directly.
prefixes = {
    "owl": "PREFIX owl: <http://www.w3.org/2002/07/owl#>\n",
    "rdf": "PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>\n",
    "rdfs": "PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>\n",
    "xsd": "PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>\n",
}

# RDF: classes are the distinct objects of rdf:type (`a`) triples.
cls_query_rdf = "".join(
    (
        prefixes["rdfs"],
        "SELECT DISTINCT ?cls ?com\n",
        "WHERE { \n",
        " ?instance a ?cls . \n",
        " OPTIONAL { ?cls rdfs:comment ?com } \n",
        "}",
    )
)

# RDFS: additionally follows rdfs:subClassOf* from each instance's type.
cls_query_rdfs = "".join(
    (
        prefixes["rdfs"],
        "SELECT DISTINCT ?cls ?com\n",
        "WHERE { \n",
        " ?instance a/rdfs:subClassOf* ?cls . \n",
        " OPTIONAL { ?cls rdfs:comment ?com } \n",
        "}",
    )
)

# OWL: same as the RDFS class query, restricted to IRIs via isIRI.
cls_query_owl = "".join(
    (
        prefixes["rdfs"],
        "SELECT DISTINCT ?cls ?com\n",
        "WHERE { \n",
        " ?instance a/rdfs:subClassOf* ?cls . \n",
        " FILTER (isIRI(?cls)) . \n",
        " OPTIONAL { ?cls rdfs:comment ?com } \n",
        "}",
    )
)

# RDF: relationships are the distinct predicates of any triple.
rel_query_rdf = "".join(
    (
        prefixes["rdfs"],
        "SELECT DISTINCT ?rel ?com\n",
        "WHERE { \n",
        " ?subj ?rel ?obj . \n",
        " OPTIONAL { ?rel rdfs:comment ?com } \n",
        "}",
    )
)

# RDFS: relationships reached via the a/rdfs:subPropertyOf* path to rdf:Property.
rel_query_rdfs = "".join(
    (
        prefixes["rdf"],
        prefixes["rdfs"],
        "SELECT DISTINCT ?rel ?com\n",
        "WHERE { \n",
        " ?rel a/rdfs:subPropertyOf* rdf:Property . \n",
        " OPTIONAL { ?rel rdfs:comment ?com } \n",
        "}",
    )
)

# OWL: object properties (path a/rdfs:subPropertyOf* to owl:ObjectProperty).
op_query_owl = "".join(
    (
        prefixes["rdfs"],
        prefixes["owl"],
        "SELECT DISTINCT ?op ?com\n",
        "WHERE { \n",
        " ?op a/rdfs:subPropertyOf* owl:ObjectProperty . \n",
        " OPTIONAL { ?op rdfs:comment ?com } \n",
        "}",
    )
)

# OWL: datatype properties (path a/rdfs:subPropertyOf* to owl:DatatypeProperty).
dp_query_owl = "".join(
    (
        prefixes["rdfs"],
        prefixes["owl"],
        "SELECT DISTINCT ?dp ?com\n",
        "WHERE { \n",
        " ?dp a/rdfs:subPropertyOf* owl:DatatypeProperty . \n",
        " OPTIONAL { ?dp rdfs:comment ?com } \n",
        "}",
    )
)
class RdfGraph:
    """RDFlib wrapper for graph operations.

    Modes:
    * local: Local file - can be queried and changed
    * online: Online file - can only be queried, changes can be stored locally
    * store: Triple store - can be queried and changed if update_endpoint available
    Together with a source file, the serialization should be specified.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(
        self,
        source_file: Optional[str] = None,
        serialization: Optional[str] = "ttl",
        query_endpoint: Optional[str] = None,
        update_endpoint: Optional[str] = None,
        standard: Optional[str] = "rdf",
        local_copy: Optional[str] = None,
        graph_kwargs: Optional[Dict] = None,
        store_kwargs: Optional[Dict] = None,
    ) -> None:
        """
        Set up the RDFlib graph

        :param source_file: either a path for a local file or a URL
        :param serialization: serialization of the input
        :param query_endpoint: SPARQL endpoint for queries, read access
        :param update_endpoint: SPARQL endpoint for UPDATE queries, write access
        :param standard: RDF, RDFS, or OWL
        :param local_copy: new local copy for storing changes
        :param graph_kwargs: Additional rdflib.Graph specific kwargs
            that will be used to initialize it,
            if query_endpoint is provided.
        :param store_kwargs: Additional sparqlstore.SPARQLStore specific kwargs
            that will be used to initialize it,
            if query_endpoint is provided.
        :raises ImportError: if rdflib is not installed
        :raises ValueError: if ``standard`` is unsupported, or if the
            combination of ``source_file`` and endpoints is ambiguous
        :raises AssertionError: if the loaded graph contains no triples
        """
        self.source_file = source_file
        self.serialization = serialization
        self.query_endpoint = query_endpoint
        self.update_endpoint = update_endpoint
        self.standard = standard
        self.local_copy = local_copy

        try:
            import rdflib
            from rdflib.plugins.stores import sparqlstore
        except ImportError as e:
            raise ImportError(
                "Could not import rdflib python package. "
                "Please install it with `pip install rdflib`."
            ) from e
        if self.standard not in (supported_standards := ("rdf", "rdfs", "owl")):
            raise ValueError(
                f"Invalid standard. Supported standards are: {supported_standards}."
            )

        # Exactly one data source is allowed: a file, or a triple store.
        nothing_to_load = not source_file and not query_endpoint
        file_and_store = source_file and (query_endpoint or update_endpoint)
        if nothing_to_load or file_and_store:
            raise ValueError(
                "Could not unambiguously initialize the graph wrapper. "
                "Specify either a file (local or online) via the source_file "
                "or a triple store via the endpoints."
            )

        if source_file:
            # Match the URL scheme, not just the "http" prefix, so that a local
            # path such as "httpdata/graph.ttl" is not mistaken for a URL.
            if source_file.startswith(("http://", "https://")):
                self.mode = "online"
            else:
                self.mode = "local"
                # Local files are updated in place unless a copy was requested.
                if self.local_copy is None:
                    self.local_copy = self.source_file
            self.graph = rdflib.Graph()
            self.graph.parse(source_file, format=self.serialization)

        if query_endpoint:
            store_kwargs = store_kwargs or {}
            self.mode = "store"
            if not update_endpoint:
                self._store = sparqlstore.SPARQLStore(**store_kwargs)
                self._store.open(query_endpoint)
            else:
                self._store = sparqlstore.SPARQLUpdateStore(**store_kwargs)
                self._store.open((query_endpoint, update_endpoint))
            graph_kwargs = graph_kwargs or {}
            self.graph = rdflib.Graph(self._store, **graph_kwargs)

        # Verify that the graph was loaded
        if not len(self.graph):
            raise AssertionError("The graph is empty.")

        # Set schema
        self.schema = ""
        self.load_schema()

    @property
    def get_schema(self) -> str:
        """
        Returns the schema of the graph database.
        """
        return self.schema

    def query(
        self,
        query: str,
    ) -> List[rdflib.query.ResultRow]:
        """
        Query the graph.

        :param query: SPARQL query text
        :return: the ``ResultRow`` items of the query result
        :raises ValueError: if rdflib raises ``ParserError`` for the query
        """
        from rdflib.exceptions import ParserError
        from rdflib.query import ResultRow

        try:
            res = self.graph.query(query)
        except ParserError as e:
            raise ValueError(f"Generated SPARQL statement is invalid\n{e}") from e
        return [r for r in res if isinstance(r, ResultRow)]

    def update(
        self,
        query: str,
    ) -> None:
        """
        Update the graph.

        After the update is applied, the graph is serialized to
        ``self.local_copy`` when one is set; the serialization format is taken
        from the file extension of that path.

        :param query: SPARQL UPDATE text
        :raises ValueError: if rdflib raises ``ParserError`` for the query, or
            if a file-based graph has no ``local_copy`` to persist changes to
        """
        from rdflib.exceptions import ParserError

        try:
            self.graph.update(query)
        except ParserError as e:
            raise ValueError(f"Generated SPARQL statement is invalid\n{e}") from e
        if self.local_copy:
            self.graph.serialize(
                destination=self.local_copy, format=self.local_copy.split(".")[-1]
            )
        elif self.mode == "store":
            # The update went through the store's update endpoint; there is no
            # local file to rewrite, so a missing local_copy is not an error.
            return
        else:
            raise ValueError("No target file specified for saving the updated file.")

    @staticmethod
    def _get_local_name(iri: str) -> str:
        """
        Return the part of ``iri`` after its last ``#``, or, when it has no
        ``#``, after its last ``/``.

        :raises ValueError: if ``iri`` contains neither ``#`` nor ``/``
        """
        if "#" in iri:
            local_name = iri.split("#")[-1]
        elif "/" in iri:
            local_name = iri.split("/")[-1]
        else:
            raise ValueError(f"Unexpected IRI '{iri}', contains neither '#' nor '/'.")
        return local_name

    def _res_to_str(self, res: rdflib.query.ResultRow, var: str) -> str:
        """
        Format one result row as ``<IRI> (local name, comment)``.

        :param res: row providing the ``var`` binding and a ``com`` binding
        :param var: name of the binding that holds the IRI
        """
        return (
            "<"
            + str(res[var])
            + "> ("
            + self._get_local_name(res[var])
            + ", "
            + str(res["com"])
            + ")"
        )

    def load_schema(self) -> None:
        """
        Load the graph schema information.

        Runs the class and property queries that match ``self.standard`` and
        stores a textual summary of the results in ``self.schema``.

        :raises ValueError: if ``self.standard`` is not one of
            ``rdf``, ``rdfs`` or ``owl``
        """

        def _rdf_s_schema(
            classes: List[rdflib.query.ResultRow],
            relationships: List[rdflib.query.ResultRow],
        ) -> str:
            # Shared summary layout for the RDF and RDFS standards.
            return (
                f"In the following, each IRI is followed by the local name and "
                f"optionally its description in parentheses. \n"
                f"The RDF graph supports the following node types:\n"
                f"{', '.join([self._res_to_str(r, 'cls') for r in classes])}\n"
                f"The RDF graph supports the following relationships:\n"
                f"{', '.join([self._res_to_str(r, 'rel') for r in relationships])}\n"
            )

        if self.standard == "rdf":
            clss = self.query(cls_query_rdf)
            rels = self.query(rel_query_rdf)
            self.schema = _rdf_s_schema(clss, rels)
        elif self.standard == "rdfs":
            clss = self.query(cls_query_rdfs)
            rels = self.query(rel_query_rdfs)
            self.schema = _rdf_s_schema(clss, rels)
        elif self.standard == "owl":
            clss = self.query(cls_query_owl)
            ops = self.query(op_query_owl)
            dps = self.query(dp_query_owl)
            self.schema = (
                f"In the following, each IRI is followed by the local name and "
                f"optionally its description in parentheses. \n"
                f"The OWL graph supports the following node types:\n"
                f"{', '.join([self._res_to_str(r, 'cls') for r in clss])}\n"
                f"The OWL graph supports the following object properties, "
                f"i.e., relationships between objects:\n"
                f"{', '.join([self._res_to_str(r, 'op') for r in ops])}\n"
                f"The OWL graph supports the following data properties, "
                f"i.e., relationships between objects and literals:\n"
                f"{', '.join([self._res_to_str(r, 'dp') for r in dps])}\n"
            )
        else:
            raise ValueError(f"Mode '{self.standard}' is currently not supported.")

View File

@@ -0,0 +1,100 @@
from typing import Any, Dict, List, Optional
from langchain_community.graphs.graph_store import GraphStore
class TigerGraph(GraphStore):
    """TigerGraph wrapper for graph operations.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, conn: Any) -> None:
        """Create a new TigerGraph graph wrapper instance.

        :param conn: a ``pyTigerGraph.TigerGraphConnection`` whose
            ``ai.nlqs_host`` is set
        """
        # set_connection validates the connection and then loads the schema,
        # so a second set_schema() call here would only repeat the fetch.
        self.set_connection(conn)

    @property
    def conn(self) -> Any:
        """The connection object stored by ``set_connection``."""
        return self._conn

    @property
    def schema(self) -> Dict[str, Any]:
        """The currently cached schema."""
        return self._schema

    def get_schema(self) -> str:  # type: ignore[override]
        """Return the cached schema as a string, generating it first if it is empty."""
        if not self._schema:
            self.set_schema()
        return str(self._schema)

    def set_connection(self, conn: Any) -> None:
        """
        Validate and store the connection, then load the schema through it.

        :raises ImportError: if pyTigerGraph is not installed
        :raises TypeError: if ``conn`` is not a ``TigerGraphConnection``
        :raises ConnectionError: if ``conn.ai.nlqs_host`` is ``None``
        """
        try:
            from pyTigerGraph import TigerGraphConnection
        except ImportError as e:
            raise ImportError(
                "Could not import pyTigerGraph python package. "
                "Please install it with `pip install pyTigerGraph`."
            ) from e

        if not isinstance(conn, TigerGraphConnection):
            msg = "**conn** parameter must inherit from TigerGraphConnection"
            raise TypeError(msg)

        if conn.ai.nlqs_host is None:
            msg = (
                "**conn** parameter does not have nlqs_host parameter defined.\n"
                "Define hostname of NLQS service."
            )
            raise ConnectionError(msg)

        self._conn: TigerGraphConnection = conn
        self.set_schema()

    def set_schema(self, schema: Optional[Dict[str, Any]] = None) -> None:
        """
        Set the schema of the TigerGraph Database.
        Auto-generates Schema if **schema** is None.
        """
        self._schema = self.generate_schema() if schema is None else schema

    def generate_schema(
        self,
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Fetch the schema from the connection and return it.

        Calls ``getSchema(force=True)`` on the stored connection; the cached
        schema on this wrapper is not modified.
        """
        return self._conn.getSchema(force=True)

    def refresh_schema(self) -> None:
        """Regenerate the schema and store it as the cached schema."""
        # The freshly generated schema must be stored; calling
        # generate_schema() alone would discard it.
        self.set_schema()

    def query(self, query: str) -> Dict[str, Any]:  # type: ignore[override]
        """Query the TigerGraph database."""
        answer = self._conn.ai.query(query)
        return answer

    def register_query(
        self,
        function_header: str,
        description: str,
        docstring: str,
        param_types: Optional[dict] = None,
    ) -> List[str]:
        """
        Wrapper function to register a custom GSQL query to the TigerGraph NLQS.

        :param param_types: forwarded to ``registerCustomQuery``; a fresh empty
            dict is used when omitted
        """
        # A new dict per call avoids sharing one mutable default between calls.
        if param_types is None:
            param_types = {}
        return self._conn.ai.registerCustomQuery(
            function_header, description, docstring, param_types
        )