Skip to content

Commit 2d82e69

Browse files
authored
Add support for FLUX ControlNet models (XLabs and InstantX) (#7070)
## Summary Add support for FLUX ControlNet models (XLabs and InstantX). ## QA Instructions - [x] SD1 and SDXL ControlNets, since the ModelLoaderRegistry calls were changed. - [x] Single Xlabs controlnet - [x] Single InstantX union controlnet - [x] Single InstantX controlnet - [x] Single Shakker Labs Union controlnet - [x] Multiple controlnets - [x] Weight, start, end params all work as expected - [x] Can be used with image-to-image and inpainting. - [x] Clear error message if no VAE is passed when using InstantX controlnet. - [x] Install InstantX ControlNet in diffusers format from HF repo (`InstantX/FLUX.1-dev-Controlnet-Union`) - [x] Test all FLUX ControlNet starter models ## Merge Plan No special instructions. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_
2 parents 236c065 + 683f9a7 commit 2d82e69

30 files changed

+2245
-47
lines changed

invokeai/app/invocations/fields.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ class FieldDescriptions:
192192
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
193193
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
194194
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
195+
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
195196

196197

197198
class ImageField(BaseModel):
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from pydantic import BaseModel, Field, field_validator, model_validator
2+
3+
from invokeai.app.invocations.baseinvocation import (
4+
BaseInvocation,
5+
BaseInvocationOutput,
6+
Classification,
7+
invocation,
8+
invocation_output,
9+
)
10+
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
11+
from invokeai.app.invocations.model import ModelIdentifierField
12+
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
13+
from invokeai.app.services.shared.invocation_context import InvocationContext
14+
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
15+
16+
17+
class FluxControlNetField(BaseModel):
    # Description of a single FLUX ControlNet: the control image, the model to apply, how strongly it is
    # weighted, and over which fraction of the denoising steps it is active.
    # NOTE: Intentionally documented with comments rather than a class docstring, so that the generated schema
    # description for this model is not altered.

    image: ImageField = Field(description="The control image")
    control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
    # Either a single weight, or a list of weights.
    # NOTE(review): presumably a list holds one weight per denoising step - confirm against the consumers.
    control_weight: float | list[float] = Field(default=1, description="The weight given to the ControlNet")
    # Both step percentages are constrained to the closed range [0, 1].
    begin_step_percent: float = Field(
        default=0,
        ge=0,
        le=1,
        description="When the ControlNet is first applied (% of total steps)",
    )
    end_step_percent: float = Field(
        default=1,
        ge=0,
        le=1,
        description="When the ControlNet is last applied (% of total steps)",
    )
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
    # Only meaningful for InstantX union ControlNets (see FieldDescriptions.instantx_control_mode). The default of
    # -1 is documented there as being treated like 'None'.
    instantx_control_mode: int | None = Field(default=-1, description=FieldDescriptions.instantx_control_mode)

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
        """Run the shared `validate_weights` check on the weight(s) and return them unchanged."""
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self):
        """Run the shared `validate_begin_end_step` check on the begin/end pair once the model is built."""
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self
40+
41+
42+
@invocation_output("flux_controlnet_output")
class FluxControlNetOutput(BaseInvocationOutput):
    """FLUX ControlNet info"""

    # The single output of the "FLUX ControlNet" node: a fully populated FluxControlNetField.
    control: FluxControlNetField = OutputField(description=FieldDescriptions.control)
47+
48+
49+
@invocation(
    "flux_controlnet",
    title="FLUX ControlNet",
    tags=["controlnet", "flux"],
    category="controlnet",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxControlNetInvocation(BaseInvocation):
    """Collect FLUX ControlNet info to pass to other nodes."""

    # The inputs mirror the fields of FluxControlNetField one-to-one; `invoke` simply repackages them.
    image: ImageField = InputField(description="The control image")
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model,
        ui_type=UIType.ControlNetModel,
    )
    control_weight: float | list[float] = InputField(
        default=1.0,
        ge=-1,
        le=2,
        description="The weight given to the ControlNet",
    )
    begin_step_percent: float = InputField(
        default=0,
        ge=0,
        le=1,
        description="When the ControlNet is first applied (% of total steps)",
    )
    end_step_percent: float = InputField(
        default=1,
        ge=0,
        le=1,
        description="When the ControlNet is last applied (% of total steps)",
    )
    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
    # Note: We default to -1 instead of None, because in the workflow editor UI None is not currently supported.
    instantx_control_mode: int | None = InputField(default=-1, description=FieldDescriptions.instantx_control_mode)

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
        """Run the shared `validate_weights` check on the weight(s) and return them unchanged."""
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self):
        """Run the shared `validate_begin_end_step` check on the begin/end pair once the model is built."""
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self

    def invoke(self, context: InvocationContext) -> FluxControlNetOutput:
        """Bundle this node's inputs into a FluxControlNetField and return it as the node output.

        `context` is not used here: no models or images are loaded by this node.
        """
        control_field = FluxControlNetField(
            image=self.image,
            control_model=self.control_model,
            control_weight=self.control_weight,
            begin_step_percent=self.begin_step_percent,
            end_step_percent=self.end_step_percent,
            resize_mode=self.resize_mode,
            instantx_control_mode=self.instantx_control_mode,
        )
        return FluxControlNetOutput(control=control_field)

invokeai/app/invocations/flux_denoise.py

Lines changed: 130 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@
1616
WithBoard,
1717
WithMetadata,
1818
)
19-
from invokeai.app.invocations.model import TransformerField
19+
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
20+
from invokeai.app.invocations.model import TransformerField, VAEField
2021
from invokeai.app.invocations.primitives import LatentsOutput
2122
from invokeai.app.services.shared.invocation_context import InvocationContext
23+
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
24+
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
2225
from invokeai.backend.flux.denoise import denoise
23-
from invokeai.backend.flux.inpaint_extension import InpaintExtension
26+
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
27+
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
28+
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
2429
from invokeai.backend.flux.model import Flux
2530
from invokeai.backend.flux.sampling_utils import (
2631
clip_timestep_schedule_fractional,
@@ -44,7 +49,7 @@
4449
title="FLUX Denoise",
4550
tags=["image", "flux"],
4651
category="image",
47-
version="3.0.0",
52+
version="3.1.0",
4853
classification=Classification.Prototype,
4954
)
5055
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -87,6 +92,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
8792
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
8893
)
8994
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
95+
control: FluxControlNetField | list[FluxControlNetField] | None = InputField(
96+
default=None, input=Input.Connection, description="ControlNet models."
97+
)
98+
controlnet_vae: VAEField | None = InputField(
99+
description=FieldDescriptions.vae,
100+
input=Input.Connection,
101+
)
90102

91103
@torch.no_grad()
92104
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -167,8 +179,8 @@ def _run_diffusion(
167179

168180
inpaint_mask = self._prep_inpaint_mask(context, x)
169181

170-
b, _c, h, w = x.shape
171-
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
182+
b, _c, latent_h, latent_w = x.shape
183+
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
172184

173185
bs, t5_seq_len, _ = t5_embeddings.shape
174186
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
@@ -192,12 +204,21 @@ def _run_diffusion(
192204
noise=noise,
193205
)
194206

195-
with (
196-
transformer_info.model_on_device() as (cached_weights, transformer),
197-
ExitStack() as exit_stack,
198-
):
199-
assert isinstance(transformer, Flux)
207+
with ExitStack() as exit_stack:
208+
# Prepare ControlNet extensions.
209+
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
210+
controlnet_extensions = self._prep_controlnet_extensions(
211+
context=context,
212+
exit_stack=exit_stack,
213+
latent_height=latent_h,
214+
latent_width=latent_w,
215+
dtype=inference_dtype,
216+
device=x.device,
217+
)
200218

219+
# Load the transformer model.
220+
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
221+
assert isinstance(transformer, Flux)
201222
config = transformer_info.config
202223
assert config is not None
203224

@@ -242,6 +263,7 @@ def _run_diffusion(
242263
step_callback=self._build_step_callback(context),
243264
guidance=self.guidance,
244265
inpaint_extension=inpaint_extension,
266+
controlnet_extensions=controlnet_extensions,
245267
)
246268

247269
x = unpack(x.float(), self.height, self.width)
@@ -288,6 +310,104 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
288310
# `latents`.
289311
return mask.expand_as(latents)
290312

313+
def _prep_controlnet_extensions(
    self,
    context: InvocationContext,
    exit_stack: ExitStack,
    latent_height: int,
    latent_width: int,
    dtype: torch.dtype,
    device: torch.device,
) -> list[XLabsControlNetExtension | InstantXControlNetExtension]:
    """Build one ControlNet extension per entry in `self.control`.

    Args:
        context: Invocation context, used to load the ControlNet / VAE models and to fetch the control images.
        exit_stack: Stack on which each ControlNet model's context manager is entered. The models stay entered
            until the caller closes the stack.
        latent_height: Latent height forwarded to `prepare_controlnet_cond`.
        latent_width: Latent width forwarded to `prepare_controlnet_cond`.
        dtype: dtype forwarded to `prepare_controlnet_cond`.
        device: Device forwarded to `prepare_controlnet_cond`.

    Returns:
        The extensions, in the same order as the ControlNets in `self.control`. Empty if `self.control` is None.

    Raises:
        ValueError: If `self.control` has an unsupported type, if a ControlNet model is neither an XLabs nor an
            InstantX FLUX ControlNet, or if an InstantX ControlNet is used while `self.controlnet_vae` is None.
    """
    # Normalize the controlnet input to list[FluxControlNetField].
    controlnets: list[FluxControlNetField]
    if self.control is None:
        controlnets = []
    elif isinstance(self.control, FluxControlNetField):
        controlnets = [self.control]
    elif isinstance(self.control, list):
        controlnets = self.control
    else:
        raise ValueError(f"Unsupported controlnet type: {type(self.control)}")

    # TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
    # before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
    # minimize peak memory.

    # First, load the ControlNet models so that we can determine the ControlNet types.
    controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]

    # Calculate the controlnet conditioning tensors.
    # We do this before entering the ControlNet model context managers (final loop below), because the InstantX
    # conditioning is prepared with the VAE and we are trying to keep peak memory down.
    controlnet_conds: list[torch.Tensor] = []
    for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
        image = context.images.get_pil(controlnet.image.image_name)
        if isinstance(controlnet_model.model, InstantXControlNetFlux):
            if self.controlnet_vae is None:
                raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
            vae_info = context.models.load(self.controlnet_vae.vae)
            controlnet_conds.append(
                InstantXControlNetExtension.prepare_controlnet_cond(
                    controlnet_image=image,
                    vae_info=vae_info,
                    latent_height=latent_height,
                    latent_width=latent_width,
                    dtype=dtype,
                    device=device,
                    resize_mode=controlnet.resize_mode,
                )
            )
        elif isinstance(controlnet_model.model, XLabsControlNetFlux):
            controlnet_conds.append(
                XLabsControlNetExtension.prepare_controlnet_cond(
                    controlnet_image=image,
                    latent_height=latent_height,
                    latent_width=latent_width,
                    dtype=dtype,
                    device=device,
                    resize_mode=controlnet.resize_mode,
                )
            )
        else:
            # Fail here with a clear message. Without this branch, `controlnet_conds` would end up shorter than
            # `controlnets` and the strict zip below would raise an opaque length-mismatch error instead.
            raise ValueError(f"Unsupported ControlNet model type: {type(controlnet_model.model)}")

    # Finally, enter the ControlNet model context managers and initialize the ControlNet extensions.
    controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
    for controlnet, controlnet_cond, controlnet_model in zip(
        controlnets, controlnet_conds, controlnet_models, strict=True
    ):
        model = exit_stack.enter_context(controlnet_model)

        if isinstance(model, XLabsControlNetFlux):
            controlnet_extensions.append(
                XLabsControlNetExtension(
                    model=model,
                    controlnet_cond=controlnet_cond,
                    weight=controlnet.control_weight,
                    begin_step_percent=controlnet.begin_step_percent,
                    end_step_percent=controlnet.end_step_percent,
                )
            )
        elif isinstance(model, InstantXControlNetFlux):
            # A control mode of None or any negative value means "no control mode": the extension gets None.
            # Otherwise the mode is passed as a long tensor reshaped to [-1, 1].
            instantx_control_mode: torch.Tensor | None = None
            if controlnet.instantx_control_mode is not None and controlnet.instantx_control_mode >= 0:
                instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
                instantx_control_mode = instantx_control_mode.reshape([-1, 1])

            controlnet_extensions.append(
                InstantXControlNetExtension(
                    model=model,
                    controlnet_cond=controlnet_cond,
                    instantx_control_mode=instantx_control_mode,
                    weight=controlnet.control_weight,
                    begin_step_percent=controlnet.begin_step_percent,
                    end_step_percent=controlnet.end_step_percent,
                )
            )
        else:
            raise ValueError(f"Unsupported ControlNet model type: {type(model)}")

    return controlnet_extensions
410+
291411
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
292412
for lora in self.transformer.loras:
293413
lora_info = context.models.load(lora.lora)

invokeai/backend/flux/controlnet/__init__.py

Whitespace-only changes.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from dataclasses import dataclass
2+
3+
import torch
4+
5+
6+
@dataclass
7+
class ControlNetFluxOutput:
8+
single_block_residuals: list[torch.Tensor] | None
9+
double_block_residuals: list[torch.Tensor] | None
10+
11+
def apply_weight(self, weight: float):
12+
if self.single_block_residuals is not None:
13+
for i in range(len(self.single_block_residuals)):
14+
self.single_block_residuals[i] = self.single_block_residuals[i] * weight
15+
if self.double_block_residuals is not None:
16+
for i in range(len(self.double_block_residuals)):
17+
self.double_block_residuals[i] = self.double_block_residuals[i] * weight
18+
19+
20+
def add_tensor_lists_elementwise(
21+
list1: list[torch.Tensor] | None, list2: list[torch.Tensor] | None
22+
) -> list[torch.Tensor] | None:
23+
"""Add two tensor lists elementwise that could be None."""
24+
if list1 is None and list2 is None:
25+
return None
26+
if list1 is None:
27+
return list2
28+
if list2 is None:
29+
return list1
30+
31+
new_list: list[torch.Tensor] = []
32+
for list1_tensor, list2_tensor in zip(list1, list2, strict=True):
33+
new_list.append(list1_tensor + list2_tensor)
34+
return new_list
35+
36+
37+
def add_controlnet_flux_outputs(
    controlnet_output_1: ControlNetFluxOutput, controlnet_output_2: ControlNetFluxOutput
) -> ControlNetFluxOutput:
    """Combine two ControlNet outputs into a new one by summing their residual lists.

    The single-block and double-block residual lists are each combined with `add_tensor_lists_elementwise`, which
    also handles the case where either list is None.
    """
    summed_single = add_tensor_lists_elementwise(
        controlnet_output_1.single_block_residuals, controlnet_output_2.single_block_residuals
    )
    summed_double = add_tensor_lists_elementwise(
        controlnet_output_1.double_block_residuals, controlnet_output_2.double_block_residuals
    )
    return ControlNetFluxOutput(single_block_residuals=summed_single, double_block_residuals=summed_double)
48+
49+
50+
def sum_controlnet_flux_outputs(
    controlnet_outputs: list[ControlNetFluxOutput],
) -> ControlNetFluxOutput:
    """Fold a list of ControlNet outputs into one output using `add_controlnet_flux_outputs`.

    For an empty input list, the result has both residual lists set to None.
    """
    # Start from an output with no residuals and accumulate each ControlNet's output onto it in order.
    total = ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
    for current_output in controlnet_outputs:
        total = add_controlnet_flux_outputs(total, current_output)
    return total

0 commit comments

Comments
 (0)