Skip to content

feat: better graph validation errors #7614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 5, 2025
163 changes: 86 additions & 77 deletions invokeai/app/services/shared/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,18 @@ class Edge(BaseModel):
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")

def __str__(self):
return f"{self.source.node_id}.{self.source.field} -> {self.destination.node_id}.{self.destination.field}"

def get_output_field(node: BaseInvocation, field: str) -> Any:

def get_output_field_type(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field


def get_input_field(node: BaseInvocation, field: str) -> Any:
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
Expand Down Expand Up @@ -93,6 +96,10 @@ def is_list_or_contains_list(t):
return False


def is_any(t: Any) -> bool:
return t == Any or Any in get_args(t)


def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type:
return False
Expand All @@ -102,13 +109,7 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
# TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
if from_type and to_type:
# Ports are compatible
if (
from_type == to_type
or from_type == Any
or to_type == Any
or Any in get_args(from_type)
or Any in get_args(to_type)
):
if from_type == to_type or is_any(from_type) or is_any(to_type):
return True

if from_type in get_args(to_type):
Expand Down Expand Up @@ -140,10 +141,10 @@ def are_connections_compatible(
"""Determines if a connection between fields of two nodes is compatible."""

# TODO: handle iterators and collectors
from_node_field = get_output_field(from_node, from_field)
to_node_field = get_input_field(to_node, to_field)
from_type = get_output_field_type(from_node, from_field)
to_type = get_input_field_type(to_node, to_field)

return are_connection_types_compatible(from_node_field, to_node_field)
return are_connection_types_compatible(from_type, to_type)


T = TypeVar("T")
Expand Down Expand Up @@ -440,17 +441,19 @@ def validate_self(self) -> None:
self.get_node(edge.destination.node_id),
edge.destination.field,
):
raise InvalidEdgeError(
f"Invalid edge from {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
raise InvalidEdgeError(f"Edge source and target types do not match ({edge})")

# Validate all iterators & collectors
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
for node in self.nodes.values():
if isinstance(node, IterateInvocation) and not self._is_iterator_connection_valid(node.id):
raise InvalidEdgeError(f"Invalid iterator node {node.id}")
if isinstance(node, CollectInvocation) and not self._is_collector_connection_valid(node.id):
raise InvalidEdgeError(f"Invalid collector node {node.id}")
if isinstance(node, IterateInvocation):
err = self._is_iterator_connection_valid(node.id)
if err is not None:
raise InvalidEdgeError(f"Invalid iterator node ({node.id}): {err}")
if isinstance(node, CollectInvocation):
err = self._is_collector_connection_valid(node.id)
if err is not None:
raise InvalidEdgeError(f"Invalid collector node ({node.id}): {err}")

return None

Expand All @@ -477,11 +480,11 @@ def is_valid(self) -> bool:

def _is_destination_field_Any(self, edge: Edge) -> bool:
"""Checks if the destination field for an edge is of type typing.Any"""
return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == Any
return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == Any

def _is_destination_field_list_of_Any(self, edge: Edge) -> bool:
"""Checks if the destination field for an edge is of type typing.Any"""
return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]
return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]

def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
Expand All @@ -491,55 +494,40 @@ def _validate_edge(self, edge: Edge):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
raise InvalidEdgeError(f"One or both nodes don't exist ({edge})")

# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
raise InvalidEdgeError(
f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
)
raise InvalidEdgeError(f"Edge already exists ({edge})")

# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(
f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
)
raise InvalidEdgeError(f"Edge creates a cycle in the graph ({edge})")

# Validate that the field types are compatible
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
raise InvalidEdgeError(
f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
raise InvalidEdgeError(f"Field types are incompatible ({edge})")

# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
raise InvalidEdgeError(
f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
err = self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source)
if err is not None:
raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}")

# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
raise InvalidEdgeError(
f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination)
if err is not None:
raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")

# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
raise InvalidEdgeError(
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)

# Validate that we are not connecting collector to iterator (currently unsupported)
if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
raise InvalidEdgeError(
f"Cannot connect collector to iterator: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
if err is not None:
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")

# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
if (
Expand All @@ -548,10 +536,9 @@ def _validate_edge(self, edge: Edge):
and not self._is_destination_field_list_of_Any(edge)
and not self._is_destination_field_Any(edge)
):
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
raise InvalidEdgeError(
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
err = self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination)
if err is not None:
raise InvalidEdgeError(f"Collector input type does not match collector output type ({edge}): {err}")

def has_node(self, node_id: str) -> bool:
"""Determines whether or not a node exists in the graph."""
Expand Down Expand Up @@ -634,7 +621,7 @@ def _is_iterator_connection_valid(
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
) -> str | None:
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]

Expand All @@ -645,29 +632,47 @@ def _is_iterator_connection_valid(

# Only one input is allowed for iterators
if len(inputs) > 1:
return False
return "Iterator may only have one input edge"

input_node = self.get_node(inputs[0].node_id)

# Get input and output fields (the fields linked to the iterator's input/output)
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
input_field_type = get_output_field_type(input_node, inputs[0].field)
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]

# Input type must be a list
if get_origin(input_field) is not list:
return False
if get_origin(input_field_type) is not list:
return "Iterator input must be a collection"

# Validate that all outputs match the input type
input_field_item_type = get_args(input_field)[0]
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
return False
input_field_item_type = get_args(input_field_type)[0]
if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)):
return "Iterator outputs must connect to an input with a matching type"

# Collector input type must match all iterator output types
if isinstance(input_node, CollectInvocation):
# Traverse the graph to find the first collector input edge. Collectors validate that their collection
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0]
first_collector_input_type = get_output_field_type(
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
)
resolved_collector_type = (
first_collector_input_type
if get_origin(first_collector_input_type) is None
else get_args(first_collector_input_type)
)
if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
return "Iterator collection type must match all iterator output types"

return True
return None

def _is_collector_connection_valid(
self,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
) -> str | None:
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]

Expand All @@ -677,38 +682,42 @@ def _is_collector_connection_valid(
outputs.append(new_output)

# Get input and output fields (the fields linked to the iterator's input/output)
input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs]
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]

# Validate that all inputs are derived from or match a single type
input_field_types = {
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
resolved_type
for input_field_type in input_field_types
for resolved_type in (
[input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
)
if resolved_type != NoneType
} # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
return False # There is more than one root type
return "Collector input collection items must be of a single type"

# Get the input root type
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore

# Verify that all outputs are lists
if not all(is_list_or_contains_list(f) for f in output_fields):
return False
if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types):
return "Collector output must connect to a collection input"

# Verify that all outputs match the input type (are a base class or the same class)
if not all(
is_union_subtype(input_root_type, get_args(f)[0]) or issubclass(input_root_type, get_args(f)[0])
for f in output_fields
is_any(t)
or is_union_subtype(input_root_type, get_args(t)[0])
or issubclass(input_root_type, get_args(t)[0])
for t in output_field_types
):
return False
return "Collector outputs must connect to a collection input with a matching type"

return True
return None

def nx_graph(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the layout of this graph"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { truncate } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { JsonObject } from 'type-fest';
Expand Down Expand Up @@ -52,15 +52,12 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
const result = zPydanticValidationError.safeParse(response);
if (result.success) {
result.data.data.detail.map((e) => {
const description = truncate(e.msg.replace(/^(Value|Index|Key) error, /i, ''), { length: 256 });
toast({
id: 'QUEUE_BATCH_FAILED',
title: truncate(upperFirst(e.msg), { length: 128 }),
title: t('queue.batchFailedToQueue'),
status: 'error',
description: truncate(
`Path:
${e.loc.join('.')}`,
{ length: 128 }
),
description,
});
});
} else if (response.status !== 403) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,6 @@ describe(validateConnectionTypes.name, () => {
});

describe('special cases', () => {
it('should reject a COLLECTION input to a COLLECTION input', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', cardinality: 'COLLECTION', batch: false },
{ name: 'CollectionField', cardinality: 'COLLECTION', batch: false }
);
expect(r).toBe(false);
});

it('should accept equal types', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@ import { type FieldType, isCollection, isSingle, isSingleOrCollection } from 'fe
* @returns True if the connection is valid, false otherwise.
*/
export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
return false;
}

if (areTypesEqual(sourceType, targetType)) {
return true;
}
Expand Down
Loading