Skip to content

Feature/schema from existing graph #355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""This example demonstrates how to use the SchemaFromExistingGraphExtractor component
to automatically extract a schema from an existing Neo4j database.
"""

import asyncio

import neo4j

from neo4j_graphrag.experimental.components.schema import (
SchemaFromExistingGraphExtractor,
GraphSchema,
)


URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"
INDEX = "moviePlotsEmbedding"


async def main() -> None:
"""Run the example."""

with neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
) as driver:
extractor = SchemaFromExistingGraphExtractor(driver)
schema: GraphSchema = await extractor.run()
# schema.store_as_json("my_schema.json")
print(schema)


if __name__ == "__main__":
asyncio.run(main())
155 changes: 153 additions & 2 deletions src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from __future__ import annotations

import json

import neo4j
import logging
import warnings
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence
Expand Down Expand Up @@ -43,6 +45,7 @@
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
from neo4j_graphrag.schema import get_structured_schema


class PropertyType(BaseModel):
Expand Down Expand Up @@ -270,7 +273,12 @@ def from_file(
raise SchemaValidationError(str(e)) from e


class SchemaBuilder(Component):
class BaseSchemaBuilder(Component):
async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
raise NotImplementedError()


class SchemaBuilder(BaseSchemaBuilder):
"""
A builder class for constructing GraphSchema objects from given entities,
relations, and their interrelationships defined in a potential schema.
Expand Down Expand Up @@ -379,7 +387,7 @@ async def run(
return self.create_schema_model(node_types, relationship_types, patterns)


class SchemaFromTextExtractor(Component):
class SchemaFromTextExtractor(BaseSchemaBuilder):
"""
A component for constructing GraphSchema objects from the output of an LLM after
automatic schema extraction from text.
Expand Down Expand Up @@ -462,3 +470,146 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
"patterns": extracted_patterns,
}
)


class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
"""A class to build a GraphSchema object from an existing graph.

Uses the get_structured_schema function to extract existing node labels,
relationship types, properties and existence constraints.

By default, the built schema does not allow any additional item (property,
node label, relationship type or pattern).

Args:
driver (neo4j.Driver): connection to the neo4j database.
additional_properties (bool, default False): see GraphSchema
additional_node_types (bool, default False): see GraphSchema
additional_relationship_types (bool, default False): see GraphSchema:
additional_patterns (bool, default False): see GraphSchema:
neo4j_database (Optional | str): name of the neo4j database to use
"""

def __init__(
self,
driver: neo4j.Driver,
additional_properties: bool = False,
additional_node_types: bool = False,
additional_relationship_types: bool = False,
additional_patterns: bool = False,
neo4j_database: Optional[str] = None,
) -> None:
self.driver = driver
self.database = neo4j_database

self.additional_properties = additional_properties
self.additional_node_types = additional_node_types
self.additional_relationship_types = additional_relationship_types
self.additional_patterns = additional_patterns

@staticmethod
def _extract_required_properties(
structured_schema: dict[str, Any],
) -> list[tuple[str, str]]:
"""Extract a list of (node label (or rel type), property name) for which
an "EXISTENCE" or "KEY" constraint is defined in the DB.

Args:

structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.

Returns:

list of tuples of (node label (or rel type), property name)

"""
schema_metadata = structured_schema.get("metadata", {})
existence_constraint = [] # list of (node label, property name)
for constraint in schema_metadata.get("constraints", []):
if constraint["type"] in (
"NODE_PROPERTY_EXISTENCE",
"NODE_KEY",
"RELATIONSHIP_PROPERTY_EXISTENCE",
"RELATIONSHIP_KEY",
):
properties = constraint["properties"]
labels = constraint["labelsOrTypes"]
# note: existence constraint only apply to a single property
# and a single label
prop = properties[0]
lab = labels[0]
existence_constraint.append((lab, prop))
return existence_constraint

async def run(self) -> GraphSchema:
structured_schema = get_structured_schema(self.driver, database=self.database)
existence_constraint = self._extract_required_properties(structured_schema)

node_labels = set(structured_schema["node_props"].keys())
node_types = [
{
"label": key,
"properties": [
{
"name": p["property"],
"type": p["type"],
"required": (key, p["property"]) in existence_constraint,
}
for p in properties
],
"additional_properties": self.additional_properties,
}
for key, properties in structured_schema["node_props"].items()
]
rel_labels = set(structured_schema["rel_props"].keys())
relationship_types = [
{
"label": key,
"properties": [
{
"name": p["property"],
"type": p["type"],
"required": (key, p["property"]) in existence_constraint,
}
for p in properties
],
}
for key, properties in structured_schema["rel_props"].items()
]
patterns = [
(s["start"], s["type"], s["end"])
for s in structured_schema["relationships"]
]
# deal with nodes and relationships without properties
for source, rel, target in patterns:
if source not in node_labels:
node_labels.add(source)
node_types.append(
{
"label": source,
}
)
if target not in node_labels:
node_labels.add(target)
node_types.append(
{
"label": target,
}
)
if rel not in rel_labels:
rel_labels.add(rel)
relationship_types.append(
{
"label": rel,
}
)
return GraphSchema.model_validate(
{
"node_types": node_types,
"relationship_types": relationship_types,
"patterns": patterns,
"additional_node_types": self.additional_node_types,
"additional_relationship_types": self.additional_relationship_types,
"additional_patterns": self.additional_patterns,
}
)
Loading