
Add support for FLUX ControlNet models (XLabs and InstantX) #7070


Merged
125 commits, merged on Oct 10, 2024

Commits
2cc72b1
Add XLabs FLUX controlnet state dict key file to be used for developm…
RyanJDick Sep 30, 2024
2821ba8
Copy ControlNetFlux model from https://github.com/XLabs-AI/x-flux/blo…
RyanJDick Sep 30, 2024
58eba8b
Fix FLUX module imports for ControlNetFlux.
RyanJDick Oct 2, 2024
9758e5a
Remove duplicate FluxParams class.
RyanJDick Oct 2, 2024
eb5b662
Remove ControlNetFlux logic related to attn processor overrides.
RyanJDick Oct 2, 2024
9541156
Remove gradient checkpointing from ControlNetFlux.
RyanJDick Oct 2, 2024
62d12e6
Fix type errors and improve docs for ControlNetFlux.
RyanJDick Oct 2, 2024
1d4a58e
Add FLUX XLabs ControlNet model probing.
RyanJDick Oct 2, 2024
c81bb76
First pass at integrating FLUX ControlNets into the FLUX Denoise invo…
RyanJDick Oct 3, 2024
36515e1
Add support for FLUX controlnet weight, begin_step_percent and end_st…
RyanJDick Oct 3, 2024
92b1515
Add InstantX FLUX ControlNet state dict for unit testing.
RyanJDick Oct 4, 2024
d1a0e99
Rename ControlNetFlux -> XLabsControlNetFlux
RyanJDick Oct 4, 2024
9bcb93f
Copy model from https://github.com/huggingface/diffusers/blob/99f6082…
RyanJDick Oct 4, 2024
459cf52
Start updating imports for FluxControlNetModel
RyanJDick Oct 4, 2024
3578010
Rename FluxControlNetModel -> DiffusersControlNetFlux
RyanJDick Oct 4, 2024
a17ea9b
Remove logic for modifying attn processors from DiffusersControlNetFlux.
RyanJDick Oct 4, 2024
e93da5d
Remove LoRA stuff from DiffusersControlNetFlux.
RyanJDick Oct 4, 2024
a35b229
Remove FluxMultiControlNetModel
RyanJDick Oct 4, 2024
5bfd2ec
Remove gradient checkpointing from DiffusersControlNetFlux.
RyanJDick Oct 4, 2024
1795f4f
Fixup typing around DiffusersControlNetFluxOutput.
RyanJDick Oct 4, 2024
c0aab56
Remove DiffusersControlNetFlux.from_transformer(...).
RyanJDick Oct 4, 2024
8308e7d
Use top-level torch import for all torch stuff.
RyanJDick Oct 4, 2024
16cda33
Improve typing of zero_module().
RyanJDick Oct 4, 2024
1751c38
Migrate DiffusersControlNetFlux from diffusers-style to BFL-style.
RyanJDick Oct 4, 2024
4be3a33
Add utils for detecting XLabs ControlNet vs. InstantX ControlNet from
RyanJDick Oct 4, 2024
e733a1f
(minor) rename other_forward() -> forward()
RyanJDick Oct 4, 2024
c72c277
WIP - implement convert_diffusers_instantx_state_dict_to_bfl_format(.…
RyanJDick Oct 4, 2024
4ad135c
Finish first draft of convert_diffusers_instantx_state_dict_to_bfl_fo…
RyanJDick Oct 7, 2024
5872f05
Add unit test for convert_diffusers_instantx_state_dict_to_bfl_format…
RyanJDick Oct 7, 2024
2cd14dd
First pass of utility function to infer the FluxParams from a state d…
RyanJDick Oct 7, 2024
1a7eece
Add scripts/extract_sd_keys_and_shapes.py
RyanJDick Oct 7, 2024
728927e
Update FLUX ControlNet unit test state dicts to include shapes.
RyanJDick Oct 7, 2024
c762894
Add unit test for infer_flux_params_from_state_dict(...).
RyanJDick Oct 7, 2024
745b6db
Add unit test for infer_instantx_num_control_modes_from_state_dict().
RyanJDick Oct 7, 2024
80bc4eb
Add unit test to test the full flow of loading an InstantX ControlNet…
RyanJDick Oct 7, 2024
5673176
Update FluxControlnetModel to work with both XLabs and InstantX.
RyanJDick Oct 7, 2024
a24581e
Create flux/extensions directory.
RyanJDick Oct 7, 2024
bfc460a
Rename DiffusersControlNetFlux -> InstantXControlNetFlux.
RyanJDick Oct 7, 2024
f878e5e
Work on integrating InstantX into denoise process.
RyanJDick Oct 7, 2024
2f8f30b
Add instantx controlnet logic to FLUX model forward().
RyanJDick Oct 8, 2024
a783539
Fix circular imports related to XLabsControlNetFluxOutput and Instant…
RyanJDick Oct 8, 2024
5d11c30
Update ControlNetCheckpointProbe.get_base_type() to work with InstantX.
RyanJDick Oct 8, 2024
0dd9f1f
Bugfixes to get InstantX ControlNet working.
RyanJDick Oct 8, 2024
dea6cbd
Create a dedicated FLUX ControlNet invocation.
RyanJDick Oct 8, 2024
cd88723
Add instantx_control_mode param to FLUX ControlNet invocation.
RyanJDick Oct 8, 2024
c78eeb1
Shift the controlnet-type-specific logic into the specific ControlNet…
RyanJDick Oct 9, 2024
8bf8742
(minor) Add comment about future memory optimization.
RyanJDick Oct 9, 2024
216b36c
Try to fix test failures affecting MacOS CI runners.
RyanJDick Oct 9, 2024
2c92e8a
Revert "Try to fix test failures affecting MacOS CI runners."
RyanJDick Oct 9, 2024
6798bba
Skip tests that are failing on MacOS CI runners (for now).
RyanJDick Oct 9, 2024
8d1a458
Support installing InstantX ControlNet models from diffusers director…
RyanJDick Oct 9, 2024
859944f
Fix support for InstantX non-union models (with no single blocks).
RyanJDick Oct 9, 2024
9a8a858
update starter models to include FLUX controlnets
maryhipp Oct 9, 2024
3977ffa
update preprocessor logic to be more resilient
maryhipp Oct 9, 2024
e82d678
ui: enable controlnet controls when FLUX is main model, update schema
Oct 9, 2024
5f2279c
hide Control Mode for FLUX control net layer
Oct 9, 2024
8b1ef4b
Fix bug with InstantX input image range.
RyanJDick Oct 9, 2024
63a2e17
possibly a working FLUX controlnet graph
Oct 9, 2024
3953e60
Remove instantx_control_mode from FLUX ControlNet node.
RyanJDick Oct 9, 2024
b1567fe
Make FLUX controlnet node API more like SD API and get it working wit…
RyanJDick Oct 9, 2024
4aace24
Reduce peak memory utilization when preparing FLUX controlnet inputs.
RyanJDick Oct 10, 2024
7f7d8e5
Merge branch 'ryan/flux-controlnet-xlabs-instantx' of https://github.…
hipsterusername Oct 10, 2024
bb6d073
Use the Shakker-Labs ControlNet union model as the only FLUX Control…
RyanJDick Oct 10, 2024
683f9a7
Restore instantx_control_mode field on FLUX ControlNet invocation.
RyanJDick Oct 10, 2024
1 change: 1 addition & 0 deletions invokeai/app/invocations/fields.py
@@ -192,6 +192,7 @@ class FieldDescriptions:
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."


class ImageField(BaseModel):
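For quick reference, the control-mode mapping described in the new instantx_control_mode field description can be written out as a plain Python dict. This is illustrative only; the dict name below is hypothetical and does not appear anywhere in the PR.

# Illustrative only: INSTANTX_CONTROL_MODES is a hypothetical name, not part of this PR.
INSTANTX_CONTROL_MODES = {
    "canny": 0,
    "tile": 1,
    "depth": 2,
    "blur": 3,
    "pose": 4,
    "gray": 5,
    "low_quality": 6,
}
# Any negative value (the node's default of -1) is treated as "no control mode".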
99 changes: 99 additions & 0 deletions invokeai/app/invocations/flux_controlnet.py
@@ -0,0 +1,99 @@
from pydantic import BaseModel, Field, field_validator, model_validator

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES


class FluxControlNetField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: float | list[float] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
instantx_control_mode: int | None = Field(default=-1, description=FieldDescriptions.instantx_control_mode)

@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
validate_weights(v)
return v

@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self


@invocation_output("flux_controlnet_output")
class FluxControlNetOutput(BaseInvocationOutput):
"""FLUX ControlNet info"""

control: FluxControlNetField = OutputField(description=FieldDescriptions.control)


@invocation(
"flux_controlnet",
title="FLUX ControlNet",
tags=["controlnet", "flux"],
category="controlnet",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxControlNetInvocation(BaseInvocation):
"""Collect FLUX ControlNet info to pass to other nodes."""

image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float | list[float] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)
begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
# Note: We default to -1 instead of None, because in the workflow editor UI None is not currently supported.
instantx_control_mode: int | None = InputField(default=-1, description=FieldDescriptions.instantx_control_mode)

@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
validate_weights(v)
return v

@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self

def invoke(self, context: InvocationContext) -> FluxControlNetOutput:
return FluxControlNetOutput(
control=FluxControlNetField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
resize_mode=self.resize_mode,
instantx_control_mode=self.instantx_control_mode,
),
)
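To make the begin/end step semantics concrete, here is a minimal, self-contained sketch of how a begin_step_percent / end_step_percent pair gates a ControlNet across a denoising schedule. The helper below is hypothetical (the actual gating lives in the ControlNet extension classes added elsewhere in this PR) and only illustrates the intended meaning of the two fields.

# Hypothetical helper, not part of this PR: shows how begin/end step percentages
# select the subset of denoising steps where a ControlNet is applied.
def controlnet_is_active(step_index: int, total_steps: int, begin_pct: float, end_pct: float) -> bool:
    # Map the step index to a fraction of the schedule in [0, 1].
    frac = step_index / max(total_steps - 1, 1)
    return begin_pct <= frac <= end_pct

# Example: begin=0.0, end=0.5 applies the ControlNet only to the first half of a 30-step run.
active_steps = [i for i in range(30) if controlnet_is_active(i, 30, 0.0, 0.5)]
assert active_steps == list(range(15))  # steps 0..14 fall within the first half of the schedule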
140 changes: 130 additions & 10 deletions invokeai/app/invocations/flux_denoise.py
@@ -16,11 +16,16 @@
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule_fractional,
@@ -44,7 +49,7 @@
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.0.0",
version="3.1.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -87,6 +92,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
control: FluxControlNetField | list[FluxControlNetField] | None = InputField(
default=None, input=Input.Connection, description="ControlNet models."
)
controlnet_vae: VAEField | None = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -167,8 +179,8 @@ def _run_diffusion(

inpaint_mask = self._prep_inpaint_mask(context, x)

b, _c, h, w = x.shape
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
@@ -192,12 +204,21 @@
noise=noise,
)

with (
transformer_info.model_on_device() as (cached_weights, transformer),
ExitStack() as exit_stack,
):
assert isinstance(transformer, Flux)
with ExitStack() as exit_stack:
# Prepare ControlNet extensions.
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
controlnet_extensions = self._prep_controlnet_extensions(
context=context,
exit_stack=exit_stack,
latent_height=latent_h,
latent_width=latent_w,
dtype=inference_dtype,
device=x.device,
)

# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
assert isinstance(transformer, Flux)
config = transformer_info.config
assert config is not None

@@ -242,6 +263,7 @@ def _run_diffusion(
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
)

x = unpack(x.float(), self.height, self.width)
@@ -288,6 +310,104 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
# `latents`.
return mask.expand_as(latents)

def _prep_controlnet_extensions(
self,
context: InvocationContext,
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[XLabsControlNetExtension | InstantXControlNetExtension]:
# Normalize the controlnet input to list[ControlField].
controlnets: list[FluxControlNetField]
if self.control is None:
controlnets = []
elif isinstance(self.control, FluxControlNetField):
controlnets = [self.control]
elif isinstance(self.control, list):
controlnets = self.control
else:
raise ValueError(f"Unsupported controlnet type: {type(self.control)}")

# TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.

# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]

# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
# keep peak memory down.
controlnet_conds: list[torch.Tensor] = []
for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
image = context.images.get_pil(controlnet.image.image_name)
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
vae_info = context.models.load(self.controlnet_vae.vae)
controlnet_conds.append(
InstantXControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
vae_info=vae_info,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
resize_mode=controlnet.resize_mode,
)
)
elif isinstance(controlnet_model.model, XLabsControlNetFlux):
controlnet_conds.append(
XLabsControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
resize_mode=controlnet.resize_mode,
)
)

# Finally, load the ControlNet models and initialize the ControlNet extensions.
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
for controlnet, controlnet_cond, controlnet_model in zip(
controlnets, controlnet_conds, controlnet_models, strict=True
):
model = exit_stack.enter_context(controlnet_model)

if isinstance(model, XLabsControlNetFlux):
controlnet_extensions.append(
XLabsControlNetExtension(
model=model,
controlnet_cond=controlnet_cond,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
elif isinstance(model, InstantXControlNetFlux):
instantx_control_mode: torch.Tensor | None = None
if controlnet.instantx_control_mode is not None and controlnet.instantx_control_mode >= 0:
instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
instantx_control_mode = instantx_control_mode.reshape([-1, 1])

controlnet_extensions.append(
InstantXControlNetExtension(
model=model,
controlnet_cond=controlnet_cond,
instantx_control_mode=instantx_control_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
else:
raise ValueError(f"Unsupported ControlNet model type: {type(model)}")

return controlnet_extensions

def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
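The instantx_control_mode handling inside _prep_controlnet_extensions above is easy to exercise in isolation. Below is a small standalone sketch (assuming only torch) of the same conversion: a non-negative integer becomes a [1, 1] long tensor, while None or a negative value disables the mode. The helper name is ours, not the PR's.

import torch

# Standalone sketch of the conversion performed inline above; the function name is hypothetical.
def make_instantx_control_mode_tensor(mode: int | None) -> torch.Tensor | None:
    if mode is None or mode < 0:
        return None
    return torch.tensor(mode, dtype=torch.long).reshape([-1, 1])

assert make_instantx_control_mode_tensor(None) is None
assert make_instantx_control_mode_tensor(-1) is None          # the node's default of -1 disables the mode
assert make_instantx_control_mode_tensor(2).shape == (1, 1)   # 2 == "depth" in the mapping from fields.py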
Empty file.
58 changes: 58 additions & 0 deletions invokeai/backend/flux/controlnet/controlnet_flux_output.py
@@ -0,0 +1,58 @@
from dataclasses import dataclass

import torch


@dataclass
class ControlNetFluxOutput:
single_block_residuals: list[torch.Tensor] | None
double_block_residuals: list[torch.Tensor] | None

def apply_weight(self, weight: float):
if self.single_block_residuals is not None:
for i in range(len(self.single_block_residuals)):
self.single_block_residuals[i] = self.single_block_residuals[i] * weight
if self.double_block_residuals is not None:
for i in range(len(self.double_block_residuals)):
self.double_block_residuals[i] = self.double_block_residuals[i] * weight


def add_tensor_lists_elementwise(
list1: list[torch.Tensor] | None, list2: list[torch.Tensor] | None
) -> list[torch.Tensor] | None:
"""Add two tensor lists elementwise that could be None."""
if list1 is None and list2 is None:
return None
if list1 is None:
return list2
if list2 is None:
return list1

new_list: list[torch.Tensor] = []
for list1_tensor, list2_tensor in zip(list1, list2, strict=True):
new_list.append(list1_tensor + list2_tensor)
return new_list


def add_controlnet_flux_outputs(
controlnet_output_1: ControlNetFluxOutput, controlnet_output_2: ControlNetFluxOutput
) -> ControlNetFluxOutput:
return ControlNetFluxOutput(
single_block_residuals=add_tensor_lists_elementwise(
controlnet_output_1.single_block_residuals, controlnet_output_2.single_block_residuals
),
double_block_residuals=add_tensor_lists_elementwise(
controlnet_output_1.double_block_residuals, controlnet_output_2.double_block_residuals
),
)


def sum_controlnet_flux_outputs(
controlnet_outputs: list[ControlNetFluxOutput],
) -> ControlNetFluxOutput:
controlnet_output_sum = ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)

for controlnet_output in controlnet_outputs:
controlnet_output_sum = add_controlnet_flux_outputs(controlnet_output_sum, controlnet_output)

return controlnet_output_sum
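A short usage sketch of the helpers above, assuming only torch plus the definitions in this file. It shows the None-propagation behaviour of add_tensor_lists_elementwise as exercised through sum_controlnet_flux_outputs; the example tensors and variable names are hypothetical.

import torch

# Two hypothetical ControlNet outputs: the first has no single-block residuals.
out_a = ControlNetFluxOutput(
    single_block_residuals=None,
    double_block_residuals=[torch.ones(2, 4), torch.ones(2, 4)],
)
out_b = ControlNetFluxOutput(
    single_block_residuals=[torch.zeros(2, 4)],
    double_block_residuals=[torch.full((2, 4), 2.0), torch.full((2, 4), 2.0)],
)

out_a.apply_weight(0.5)  # scale out_a's residuals by its ControlNet weight

combined = sum_controlnet_flux_outputs([out_a, out_b])
# single_block_residuals comes from out_b alone (out_a contributed None);
# double_block_residuals is the elementwise sum 0.5 + 2.0.
assert combined.single_block_residuals is not None and len(combined.single_block_residuals) == 1
assert torch.allclose(combined.double_block_residuals[0], torch.full((2, 4), 2.5))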