Skip to content

Commit 4a52130

Browse files
committed
Extract required properties from existing constraints
1 parent 4bcfbb1 commit 4a52130

File tree

1 file changed

+75
-4
lines changed
  • src/neo4j_graphrag/experimental/components

1 file changed

+75
-4
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,13 +457,78 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
457457

458458

459459
class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
460-
"""A class to build a GraphSchema object from an existing graph."""
460+
"""A class to build a GraphSchema object from an existing graph.
461461
462-
def __init__(self, driver: neo4j.Driver) -> None:
462+
Uses the get_structured_schema function to extract existing node labels,
463+
relationship types, properties and existence constraints.
464+
465+
By default, the built schema does not allow any additional item (property,
466+
node label, relationship type or pattern).
467+
468+
Args:
469+
driver (neo4j.Driver): connection to the neo4j database.
470+
additional_properties (bool, default False): see GraphSchema
471+
additional_node_types (bool, default False): see GraphSchema
472+
additional_relationship_types (bool, default False): see GraphSchema
473+
additional_patterns (bool, default False): see GraphSchema
474+
neo4j_database (Optional[str]): name of the neo4j database to use
475+
"""
476+
477+
def __init__(
    self,
    driver: neo4j.Driver,
    additional_properties: bool = False,
    additional_node_types: bool = False,
    additional_relationship_types: bool = False,
    additional_patterns: bool = False,
    neo4j_database: Optional[str] = None,
) -> None:
    """Record the database connection and the schema "additional_*" flags.

    Nothing is queried here; the arguments are only stored on the instance.

    Args:
        driver (neo4j.Driver): connection to the neo4j database.
        additional_properties (bool): stored as ``self.additional_properties``.
        additional_node_types (bool): stored as ``self.additional_node_types``.
        additional_relationship_types (bool): stored as
            ``self.additional_relationship_types``.
        additional_patterns (bool): stored as ``self.additional_patterns``.
        neo4j_database (Optional[str]): name of the neo4j database to use,
            stored as ``self.database``.
    """
    # Connection settings.
    self.driver, self.database = driver, neo4j_database

    # Flags controlling whether items beyond the extracted ones are allowed.
    (
        self.additional_properties,
        self.additional_node_types,
        self.additional_relationship_types,
        self.additional_patterns,
    ) = (
        additional_properties,
        additional_node_types,
        additional_relationship_types,
        additional_patterns,
    )
493+
494+
@staticmethod
495+
def _extract_required_properties(
496+
structured_schema: dict[str, Any],
497+
) -> list[tuple[str, str]]:
498+
"""Extract a list of (node label (or rel type), property name) for which
499+
an "EXISTENCE" or "KEY" constraint is defined in the DB.
500+
501+
Args:
502+
503+
structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.
504+
505+
Returns:
506+
507+
list of tuples of (node label (or rel type), property name)
508+
509+
"""
510+
schema_metadata = structured_schema.get("metadata", {})
511+
existence_constraint = [] # list of (node label, property name)
512+
for constraint in schema_metadata.get("constraints", []):
513+
if constraint["type"] in (
514+
"NODE_PROPERTY_EXISTENCE",
515+
"NODE_KEY",
516+
"RELATIONSHIP_PROPERTY_EXISTENCE",
517+
"RELATIONSHIP_KEY",
518+
):
519+
properties = constraint["properties"]
520+
labels = constraint["labelsOrTypes"]
521+
# note: existence constraint only apply to a single property
522+
# and a single label
523+
prop = properties[0]
524+
lab = labels[0]
525+
existence_constraint.append((lab, prop))
526+
return existence_constraint
527+
528+
async def run(self) -> GraphSchema:
529+
structured_schema = get_structured_schema(self.driver, database=self.database)
530+
existence_constraint = self._extract_required_properties(structured_schema)
464531

465-
async def run(self, **kwargs: Any) -> GraphSchema:
466-
structured_schema = get_structured_schema(self.driver)
467532
node_labels = set(structured_schema["node_props"].keys())
468533
node_types = [
469534
{
@@ -472,9 +537,11 @@ async def run(self, **kwargs: Any) -> GraphSchema:
472537
{
473538
"name": p["property"],
474539
"type": p["type"],
540+
"required": (key, p["property"]) in existence_constraint,
475541
}
476542
for p in properties
477543
],
544+
"additional_properties": self.additional_properties,
478545
}
479546
for key, properties in structured_schema["node_props"].items()
480547
]
@@ -486,6 +553,7 @@ async def run(self, **kwargs: Any) -> GraphSchema:
486553
{
487554
"name": p["property"],
488555
"type": p["type"],
556+
"required": (key, p["property"]) in existence_constraint,
489557
}
490558
for p in properties
491559
],
@@ -524,5 +592,8 @@ async def run(self, **kwargs: Any) -> GraphSchema:
524592
"node_types": node_types,
525593
"relationship_types": relationship_types,
526594
"patterns": patterns,
595+
"additional_node_types": self.additional_node_types,
596+
"additional_relationship_types": self.additional_relationship_types,
597+
"additional_patterns": self.additional_patterns,
527598
}
528599
)

0 commit comments

Comments
 (0)