Commit 4d5f74c

LoRA refactor to enable FLUX control LoRAs w/ quantized transformers (#7446)
## Summary

This PR refactors the LoRA handling code to enable the use of FLUX control LoRAs on top of quantized transformers.

Changes:
- Renamed a bunch of the model patching utilities to reflect that they are not LoRA-specific.
- Improved the unit test coverage.
- Refactored the handling of 'sidecar' patch layers to make them work with more layer patch types. (This was necessary to get FLUX control LoRAs working on top of quantized models.)
- Removed `ONNXModelPatcher`. It is out of date and hasn't been used in a while.

## QA Instructions

I completed the following tests. **These should be repeated after changing the target branch to main.**

**Due to the large surface area of this PR, reviewers should do regression tests on a range of LoRA formats. There is a risk of regression on a specific format that was missed during the refactoring.**

- [x] FLUX Control LoRA + full FLUX transformer
- [x] FLUX Control LoRA + BnB NF4 quantized transformer
- [x] FLUX Control LoRA + GGUF quantized transformer
- [x] FLUX Control LoRA + non-control LoRA + full FLUX transformer
- [x] FLUX Control LoRA + non-control LoRA + BnB quantized transformer
- [x] FLUX Control LoRA + non-control LoRA + GGUF quantized transformer
- Test the following cases for regression:
  - [x] Misc SD1/SDXL LoRA variants (LoRA, LoKr, IA3)
  - [x] FLUX, non-quantized, variety of LoRA formats
  - [x] FLUX, quantized, variety of LoRA formats

## Merge Plan

**_Don't merge this PR yet._**

Merge plan:
1. First merge brandon/flux-tools-loras into main.
2. Change the target branch of this PR to main.
3. Review / test / merge this PR.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2 parents 5f41a69 + dd09509 commit 4d5f74c
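
The summary above distinguishes two patching modes: fusing LoRA weights directly into a layer (only possible when the weights are not quantized) and attaching 'sidecar' layers that add the low-rank residual at forward time. The following is a minimal illustrative sketch of the sidecar idea, assuming a plain linear layer; the class name and shapes are hypothetical, not the actual `LayerPatcher` internals:

```python
import torch

class LinearSidecarPatch(torch.nn.Module):
    """Toy sidecar wrapper: leaves the wrapped layer's (possibly quantized)
    weights untouched and adds the LoRA residual to its output."""

    def __init__(self, orig: torch.nn.Module, down: torch.Tensor, up: torch.Tensor, weight: float):
        super().__init__()
        self.orig = orig      # e.g. a BnB NF4 or GGUF quantized linear
        self.down = down      # (rank, in_features)
        self.up = up          # (out_features, rank)
        self.weight = weight  # user-specified LoRA strength

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = orig(x) + weight * (x @ down^T @ up^T)
        return self.orig(x) + self.weight * (x @ self.down.T @ self.up.T)

layer = torch.nn.Linear(8, 4)
patched = LinearSidecarPatch(layer, down=torch.randn(2, 8), up=torch.randn(4, 2), weight=0.75)
y = patched(torch.randn(1, 8))  # same interface as the original layer
```

The extra matmuls per forward pass are why the sidecar path is slower than direct patching, but since the wrapped layer's weights are never modified, the same mechanism works over full-precision, BnB NF4, or GGUF layers alike.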

File tree

68 files changed: +962 −664 lines


invokeai/app/invocations/compel.py

Lines changed: 8 additions & 8 deletions

```diff
@@ -19,9 +19,9 @@
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
@@ -66,10 +66,10 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:
         tokenizer_info = context.models.load(self.clip.tokenizer)
         text_encoder_info = context.models.load(self.clip.text_encoder)
 
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.clip.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
@@ -82,7 +82,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LayerPatcher.apply_model_patches(
                 model=text_encoder,
                 patches=_lora_loader(),
                 prefix="lora_te_",
@@ -162,11 +162,11 @@ def run_clip_compel(
                 c_pooled = None
             return c, c_pooled
 
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in clip_field.loras:
                 lora_info = context.models.load(lora.lora)
                 lora_model = lora_info.model
-                assert isinstance(lora_model, LoRAModelRaw)
+                assert isinstance(lora_model, ModelPatchRaw)
                 yield (lora_model, lora.weight)
                 del lora_info
             return
@@ -179,7 +179,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LayerPatcher.apply_model_patches(
                 text_encoder,
                 patches=_lora_loader(),
                 prefix=lora_prefix,
```
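
For contrast with the sidecar sketch above, the `LayerPatcher.apply_model_patches` context manager used in these hunks follows the direct-patching pattern: fuse the low-rank product into the weight on entry, restore the saved original on exit. A hedged, self-contained sketch of that pattern (the helper below is hypothetical, not InvokeAI code):

```python
import torch

def patch_linear_(layer: torch.nn.Linear, down: torch.Tensor, up: torch.Tensor, weight: float) -> torch.Tensor:
    """Fuse a LoRA residual into the layer's weight in place; return the
    original weight so the caller can restore it when unpatching."""
    original = layer.weight.detach().clone()
    with torch.no_grad():
        layer.weight += weight * (up @ down)  # (out, rank) @ (rank, in) -> (out, in)
    return original

layer = torch.nn.Linear(8, 4, bias=False)
saved = patch_linear_(layer, down=torch.randn(2, 8), up=torch.randn(4, 2), weight=0.75)
# ... run inference with the fused weights (no per-step LoRA overhead) ...
with torch.no_grad():
    layer.weight.copy_(saved)  # restore on exit, as the context manager does
```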

invokeai/app/invocations/denoise_latents.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -37,10 +37,10 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@@ -987,10 +987,10 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
         def step_callback(state: PipelineIntermediateState) -> None:
             context.util.sd_step_callback(state, unet_config.base)
 
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.unet.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
@@ -1003,7 +1003,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            LoRAPatcher.apply_lora_patches(
+            LayerPatcher.apply_model_patches(
                 model=unet,
                 patches=_lora_loader(),
                 prefix="lora_unet_",
```
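
This invocation enters several patchers (FreeU, seamless, LoRA) in one `with` statement, and other call sites in this PR push patchers onto an `ExitStack`. A self-contained toy showing why that pattern is used: patches can be entered conditionally, and all of them unwind in reverse order even if denoising raises.

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def toy_patch(model: dict, key: str, value: float):
    model[key] = value           # apply the patch
    try:
        yield model
    finally:
        del model[key]           # always reverted, even if the body raises

model: dict = {}
use_lora = True                  # patches can be chosen at runtime
with ExitStack() as stack:
    stack.enter_context(toy_patch(model, "freeu", 1.0))
    if use_lora:                 # conditional entry is awkward with a plain `with`
        stack.enter_context(toy_patch(model, "lora_unet_", 0.75))
    assert model == {"freeu": 1.0, "lora_unet_": 0.75}
assert model == {}               # everything unwound in reverse order
```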

invokeai/app/invocations/flux_denoise.py

Lines changed: 7 additions & 7 deletions

```diff
@@ -47,10 +47,10 @@
     unpack,
 )
 from invokeai.backend.flux.text_conditioning import FluxTextConditioning
-from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
@@ -306,7 +306,7 @@ def _run_diffusion(
         if config.format in [ModelFormat.Checkpoint]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LayerPatcher.apply_model_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
                     prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -321,7 +321,7 @@ def _run_diffusion(
             # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference
             # than directly patching the weights, but is agnostic to the quantization format.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_sidecar_patches(
+                LayerPatcher.apply_model_sidecar_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
                     prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -715,15 +715,15 @@ def _prep_ip_adapter_extensions(
 
         return pos_ip_adapter_extensions, neg_ip_adapter_extensions
 
-    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
         loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
         if self.control_lora:
             # Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they
             # are applied last.
             loras.append(self.control_lora)
         for lora in loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, LoRAModelRaw)
+            assert isinstance(lora_info.model, ModelPatchRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
```
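
The ordering note in `_lora_iterator` is load-bearing: a FLUX structural control LoRA replaces some weights with differently-shaped tensors (extra input channels), so ordinary LoRA patches that assume the original shapes must be applied first. A toy demonstration of the shape hazard, with made-up dimensions:

```python
import torch

layer = torch.nn.Linear(8, 4, bias=False)

# 1) Ordinary LoRA patch: valid against the original (4, 8) weight.
down, up = torch.randn(2, 8), torch.randn(4, 2)
with torch.no_grad():
    layer.weight += up @ down    # (4, 2) @ (2, 8) -> (4, 8), shapes agree

# 2) Structural control patch: swap in an expanded (4, 16) weight to accept
#    extra control channels. Running step 1 *after* this would raise a shape
#    mismatch, which is why the control LoRA is appended last.
expanded = torch.zeros(4, 16)
expanded[:, :8] = layer.weight.detach()
layer.weight = torch.nn.Parameter(expanded)
layer.in_features = 16

y = layer(torch.randn(1, 16))    # the patched layer now expects 16 inputs
```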

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -17,10 +17,10 @@
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.conditioner import HFEncoder
-from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
 
 
@@ -111,7 +111,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LayerPatcher.apply_model_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context),
                     prefix=FLUX_LORA_CLIP_PREFIX,
@@ -130,9 +130,9 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
         return pooled_prompt_embeds
 
-    def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+    def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
         for lora in self.clip.loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, LoRAModelRaw)
+            assert isinstance(lora_info.model, ModelPatchRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
```

invokeai/app/invocations/sd3_text_encoder.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -16,10 +16,10 @@
 from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import SD3ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
 
 # The SD3 T5 Max Sequence Length set based on the default in diffusers.
@@ -150,7 +150,7 @@ def _clip_encode(
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LayerPatcher.apply_model_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context, clip_model),
                     prefix=FLUX_LORA_CLIP_PREFIX,
@@ -193,9 +193,9 @@ def _clip_encode(
 
     def _clip_lora_iterator(
         self, context: InvocationContext, clip_model: CLIPField
-    ) -> Iterator[Tuple[LoRAModelRaw, float]]:
+    ) -> Iterator[Tuple[ModelPatchRaw, float]]:
         for lora in clip_model.loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, LoRAModelRaw)
+            assert isinstance(lora_info.model, ModelPatchRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
```

invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -22,8 +22,8 @@
 from invokeai.app.invocations.model import UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
 from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
     MultiDiffusionPipeline,
@@ -194,10 +194,10 @@ def step_callback(state: PipelineIntermediateState) -> None:
             context.util.sd_step_callback(state, unet_config.base)
 
         # Prepare an iterator that yields the UNet's LoRA models and their weights.
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.unet.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
 
@@ -207,7 +207,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         with (
             ExitStack() as exit_stack,
             unet_info as unet,
-            LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+            LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
```

invokeai/backend/lora/layers/any_lora_layer.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

invokeai/backend/lora/layers/set_parameter_layer.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

invokeai/backend/lora/sidecar_layers/concatenated_lora/__init__.py

Whitespace-only changes.

invokeai/backend/lora/sidecar_layers/concatenated_lora/concatenated_lora_linear_sidecar_layer.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

invokeai/backend/lora/sidecar_layers/lora/__init__.py

Whitespace-only changes.

invokeai/backend/lora/sidecar_layers/lora/lora_linear_sidecar_layer.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

invokeai/backend/lora/sidecar_layers/lora_sidecar_layer.py

Whitespace-only changes.

invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py

Lines changed: 0 additions & 24 deletions
This file was deleted.
