@@ -198,10 +198,10 @@ def list_ancestors(self):
198
198
ancestors += self .parent .list_ancestors ()
199
199
return ancestors
200
200
201
- def validate_node (self ):
201
+ def validate_node (self , * args ):
202
202
if self .children :
203
203
for child in self .children :
204
- child .validate_node ()
204
+ child .validate_node (* args )
205
205
206
206
# Returns all tokens in the span covered by this node
207
207
def list_tokens (self ):
@@ -304,8 +304,8 @@ class Root(Node):
304
304
def __init__ (self ):
305
305
super ().__init__ ("ROOT" )
306
306
307
- def validate_node (self ):
308
- super ().validate_node ()
307
+ def validate_node (self , * args ):
308
+ super ().validate_node (* args )
309
309
for child in self .children :
310
310
if type (child ) == Slot or type (child ) == Root :
311
311
raise TypeError (
@@ -324,8 +324,8 @@ class Intent(Node):
324
324
def __init__ (self , label ):
325
325
super ().__init__ (label )
326
326
327
- def validate_node (self ):
328
- super ().validate_node ()
327
+ def validate_node (self , * args ):
328
+ super ().validate_node (* args )
329
329
for child in self .children :
330
330
if type (child ) == Intent or type (child ) == Root :
331
331
raise TypeError (
@@ -342,8 +342,12 @@ class Slot(Node):
342
342
def __init__ (self , label ):
343
343
super ().__init__ (label )
344
344
345
- def validate_node (self ):
346
- super ().validate_node ()
345
+ def validate_node (self , allow_empty_slots = True , * args ):
346
+ super ().validate_node (* args )
347
+ if not allow_empty_slots :
348
+ if len (self .children ) == 0 :
349
+ raise TypeError ("Empty slot found: " + self .label )
350
+
347
351
for child in self .children :
348
352
if type (child ) == Slot or type (child ) == Root :
349
353
raise TypeError (
@@ -362,7 +366,7 @@ def __init__(self, label, index):
362
366
self .index = index
363
367
self .children = None
364
368
365
- def validate_node (self ):
369
+ def validate_node (self , * args ):
366
370
if self .children is not None :
367
371
raise TypeError (
368
372
"A token node is terminal and should not \
@@ -508,7 +512,7 @@ def __init__(
508
512
+ "Utterance is: {}" .format (utterance )
509
513
)
510
514
511
- def validate_tree (self ):
515
+ def validate_tree (self , allow_empty_slots = True ):
512
516
"""
513
517
This is a method for checking that roots/intents/slots are
514
518
nested correctly.
@@ -525,16 +529,16 @@ def validate_tree(self):
525
529
COMBINATION_SLOT_LABEL ,
526
530
)
527
531
)
528
- self .recursive_validation (self .root )
532
+ self .recursive_validation (self .root , allow_empty_slots )
529
533
except TypeError as t :
530
534
raise ValueError (
531
535
"Failed validation for {}" .format (self .root ) + "\n " + str (t )
532
536
)
533
537
534
- def recursive_validation (self , node ):
535
- node .validate_node ()
538
+ def recursive_validation (self , node , * args ):
539
+ node .validate_node (* args )
536
540
for child in node .children :
537
- child .validate_node ()
541
+ child .validate_node (* args )
538
542
539
543
def print_tree (self ):
540
544
print (self .flat_str ())
0 commit comments