@@ -529,10 +529,6 @@ def _validate_edge(self, edge: Edge):
529
529
if err is not None :
530
530
raise InvalidEdgeError (f"Collector output type does not match collector input type ({ edge } ): { err } " )
531
531
532
- # Validate that we are not connecting collector to iterator (currently unsupported)
533
- if isinstance (from_node , CollectInvocation ) and isinstance (to_node , IterateInvocation ):
534
- raise InvalidEdgeError (f"Cannot connect collector to iterator ({ edge } )" )
535
-
536
532
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
537
533
if (
538
534
isinstance (from_node , CollectInvocation )
@@ -638,8 +634,10 @@ def _is_iterator_connection_valid(
638
634
if len (inputs ) > 1 :
639
635
return "Iterator may only have one input edge"
640
636
637
+ input_node = self .get_node (inputs [0 ].node_id )
638
+
641
639
# Get input and output fields (the fields linked to the iterator's input/output)
642
- input_field_type = get_output_field_type (self . get_node ( inputs [ 0 ]. node_id ) , inputs [0 ].field )
640
+ input_field_type = get_output_field_type (input_node , inputs [0 ].field )
643
641
output_field_types = [get_input_field_type (self .get_node (e .node_id ), e .field ) for e in outputs ]
644
642
645
643
# Input type must be a list
@@ -651,6 +649,22 @@ def _is_iterator_connection_valid(
651
649
if not all ((are_connection_types_compatible (input_field_item_type , t ) for t in output_field_types )):
652
650
return "Iterator outputs must connect to an input with a matching type"
653
651
652
+ # Collector input type must match all iterator output types
653
+ if isinstance (input_node , CollectInvocation ):
654
+ # Traverse the graph to find the first collector input edge. Collectors validate that their collection
655
+ # inputs are all of the same type, so we can use the first input edge to determine the collector's type
656
+ first_collector_input_edge = self ._get_input_edges (input_node .id , "item" )[0 ]
657
+ first_collector_input_type = get_output_field_type (
658
+ self .get_node (first_collector_input_edge .source .node_id ), first_collector_input_edge .source .field
659
+ )
660
+ resolved_collector_type = (
661
+ first_collector_input_type
662
+ if get_origin (first_collector_input_type ) is None
663
+ else get_args (first_collector_input_type )
664
+ )
665
+ if not all ((are_connection_types_compatible (resolved_collector_type , t ) for t in output_field_types )):
666
+ return "Iterator collection type must match all iterator output types"
667
+
654
668
return None
655
669
656
670
def _is_collector_connection_valid (
0 commit comments