@@ -457,13 +457,78 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
457
457
458
458
459
459
class SchemaFromExistingGraphExtractor (BaseSchemaBuilder ):
460
- """A class to build a GraphSchema object from an existing graph."""
460
+ """A class to build a GraphSchema object from an existing graph.
461
461
462
- def __init__ (self , driver : neo4j .Driver ) -> None :
462
+ Uses the get_structured_schema function to extract existing node labels,
463
+ relationship types, properties and existence constraints.
464
+
465
+ By default, the built schema does not allow any additional item (property,
466
+ node label, relationship type or pattern).
467
+
468
+ Args:
469
+ driver (neo4j.Driver): connection to the neo4j database.
470
+ additional_properties (bool, default False): see GraphSchema
471
+ additional_node_types (bool, default False): see GraphSchema
472
+ additional_relationship_types (bool, default False): see GraphSchema:
473
+ additional_patterns (bool, default False): see GraphSchema:
474
+ neo4j_database (Optional | str): name of the neo4j database to use
475
+ """
476
+
477
+ def __init__ (
478
+ self ,
479
+ driver : neo4j .Driver ,
480
+ additional_properties : bool = False ,
481
+ additional_node_types : bool = False ,
482
+ additional_relationship_types : bool = False ,
483
+ additional_patterns : bool = False ,
484
+ neo4j_database : Optional [str ] = None ,
485
+ ) -> None :
463
486
self .driver = driver
487
+ self .database = neo4j_database
488
+
489
+ self .additional_properties = additional_properties
490
+ self .additional_node_types = additional_node_types
491
+ self .additional_relationship_types = additional_relationship_types
492
+ self .additional_patterns = additional_patterns
493
+
494
+ @staticmethod
495
+ def _extract_required_properties (
496
+ structured_schema : dict [str , Any ],
497
+ ) -> list [tuple [str , str ]]:
498
+ """Extract a list of (node label (or rel type), property name) for which
499
+ an "EXISTENCE" or "KEY" constraint is defined in the DB.
500
+
501
+ Args:
502
+
503
+ structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.
504
+
505
+ Returns:
506
+
507
+ list of tuples of (node label (or rel type), property name)
508
+
509
+ """
510
+ schema_metadata = structured_schema .get ("metadata" , {})
511
+ existence_constraint = [] # list of (node label, property name)
512
+ for constraint in schema_metadata .get ("constraints" , []):
513
+ if constraint ["type" ] in (
514
+ "NODE_PROPERTY_EXISTENCE" ,
515
+ "NODE_KEY" ,
516
+ "RELATIONSHIP_PROPERTY_EXISTENCE" ,
517
+ "RELATIONSHIP_KEY" ,
518
+ ):
519
+ properties = constraint ["properties" ]
520
+ labels = constraint ["labelsOrTypes" ]
521
+ # note: existence constraint only apply to a single property
522
+ # and a single label
523
+ prop = properties [0 ]
524
+ lab = labels [0 ]
525
+ existence_constraint .append ((lab , prop ))
526
+ return existence_constraint
527
+
528
+ async def run (self ) -> GraphSchema :
529
+ structured_schema = get_structured_schema (self .driver , database = self .database )
530
+ existence_constraint = self ._extract_required_properties (structured_schema )
464
531
465
- async def run (self , ** kwargs : Any ) -> GraphSchema :
466
- structured_schema = get_structured_schema (self .driver )
467
532
node_labels = set (structured_schema ["node_props" ].keys ())
468
533
node_types = [
469
534
{
@@ -472,9 +537,11 @@ async def run(self, **kwargs: Any) -> GraphSchema:
472
537
{
473
538
"name" : p ["property" ],
474
539
"type" : p ["type" ],
540
+ "required" : (key , p ["property" ]) in existence_constraint ,
475
541
}
476
542
for p in properties
477
543
],
544
+ "additional_properties" : self .additional_properties ,
478
545
}
479
546
for key , properties in structured_schema ["node_props" ].items ()
480
547
]
@@ -486,6 +553,7 @@ async def run(self, **kwargs: Any) -> GraphSchema:
486
553
{
487
554
"name" : p ["property" ],
488
555
"type" : p ["type" ],
556
+ "required" : (key , p ["property" ]) in existence_constraint ,
489
557
}
490
558
for p in properties
491
559
],
@@ -524,5 +592,8 @@ async def run(self, **kwargs: Any) -> GraphSchema:
524
592
"node_types" : node_types ,
525
593
"relationship_types" : relationship_types ,
526
594
"patterns" : patterns ,
595
+ "additional_node_types" : self .additional_node_types ,
596
+ "additional_relationship_types" : self .additional_relationship_types ,
597
+ "additional_patterns" : self .additional_patterns ,
527
598
}
528
599
)
0 commit comments