Skip to content

Commit 526158a

Browse files
feat(nodes): support collect -> iterate node connections w/ validation
1 parent d805896 commit 526158a

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

invokeai/app/services/shared/graph.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -529,10 +529,6 @@ def _validate_edge(self, edge: Edge):
529529
if err is not None:
530530
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
531531

532-
# Validate that we are not connecting collector to iterator (currently unsupported)
533-
if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
534-
raise InvalidEdgeError(f"Cannot connect collector to iterator ({edge})")
535-
536532
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
537533
if (
538534
isinstance(from_node, CollectInvocation)
@@ -638,8 +634,10 @@ def _is_iterator_connection_valid(
638634
if len(inputs) > 1:
639635
return "Iterator may only have one input edge"
640636

637+
input_node = self.get_node(inputs[0].node_id)
638+
641639
# Get input and output fields (the fields linked to the iterator's input/output)
642-
input_field_type = get_output_field_type(self.get_node(inputs[0].node_id), inputs[0].field)
640+
input_field_type = get_output_field_type(input_node, inputs[0].field)
643641
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
644642

645643
# Input type must be a list
@@ -651,6 +649,22 @@ def _is_iterator_connection_valid(
651649
if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)):
652650
return "Iterator outputs must connect to an input with a matching type"
653651

652+
# Collector input type must match all iterator output types
653+
if isinstance(input_node, CollectInvocation):
654+
# Traverse the graph to find the first collector input edge. Collectors validate that their collection
655+
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
656+
first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0]
657+
first_collector_input_type = get_output_field_type(
658+
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
659+
)
660+
resolved_collector_type = (
661+
first_collector_input_type
662+
if get_origin(first_collector_input_type) is None
663+
else get_args(first_collector_input_type)
664+
)
665+
if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
666+
return "Iterator collection type must match all iterator output types"
667+
654668
return None
655669

656670
def _is_collector_connection_valid(

0 commit comments

Comments
 (0)