LoRA refactor to enable FLUX control LoRAs w/ quantized transformers #7446

Merged: 25 commits, Dec 17, 2024

Commits (25)

42f8d6a
Rename backend/lora/ to backend/patches
RyanJDick Dec 13, 2024
41664f8
Rename backend/patches/conversions/ to backend/patches/lora_conversions/
RyanJDick Dec 13, 2024
693d426
Add basic unit tests for LoRALayer.
RyanJDick Dec 13, 2024
8ea697d
Mark LoRALayerBase.rank(...) as a private method.
RyanJDick Dec 13, 2024
1eede43
Delete ONNXModelPatcher. It is outdated and hasn't been used for a lo…
RyanJDick Dec 13, 2024
58de93a
Delete empty file.
RyanJDick Dec 13, 2024
2b441d6
Add BaseLayerPatch ABC to clarify the intended patch interface. (A sketch of this interface follows the commit list below.)
RyanJDick Dec 13, 2024
808e377
Remove AnyLoRALayer type definition in favor of using BaseLayerPatch …
RyanJDick Dec 13, 2024
3a8a544
Add basic unit tests for SetParameterLayer.
RyanJDick Dec 13, 2024
443d838
Add initial basic implementation of sidecar wrappers.
RyanJDick Dec 13, 2024
e2451ef
Add unit tests for LinearSidecarWrapper (and fix a bug).
RyanJDick Dec 13, 2024
1e0552c
Add optimized implementations for the LinearSidecarWrapper when using…
RyanJDick Dec 13, 2024
ac28370
Break up functions in LoRAPatcher in preparation for more refactoring.
RyanJDick Dec 13, 2024
46133b5
Switch LoRAPatcher to use the new sidecar_wrappers/ rather than sidec…
RyanJDick Dec 13, 2024
c76a448
Delete old sidecar_layers/ dir.
RyanJDick Dec 13, 2024
606d58d
Add sidecar wrapper for FLUX RMSNorm layers to support SetParameterLa…
RyanJDick Dec 13, 2024
e7e3f7e
Ensure that patches are on the correct device when used in sidecar wr…
RyanJDick Dec 13, 2024
fe09f2d
Move handling of LoRA scale and patch weight down into the layer patc…
RyanJDick Dec 13, 2024
37e3089
Push LoRA layer reshaping down into the patch layers and add a new Fl…
RyanJDick Dec 14, 2024
80f64ab
Use a FluxControlLoRALayer when loading FLUX control LoRAs.
RyanJDick Dec 14, 2024
9369b39
Add GGMLTensor op.
RyanJDick Dec 14, 2024
c604a09
Rename LoRAPatcher -> ModelPatcher.
RyanJDick Dec 14, 2024
b820862
Rename ModelPatcher methods to reflect that they are general model pa…
RyanJDick Dec 14, 2024
7fad4c9
Rename LoRAModelRaw to ModelPatchRaw.
RyanJDick Dec 14, 2024
dd09509
Rename ModelPatcher -> LayerPatcher to avoid conflicts with another M…
RyanJDick Dec 14, 2024
16 changes: 8 additions & 8 deletions invokeai/app/invocations/compel.py
@@ -19,9 +19,9 @@
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
@@ -66,10 +66,10 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:
         tokenizer_info = context.models.load(self.clip.tokenizer)
         text_encoder_info = context.models.load(self.clip.text_encoder)

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.clip.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
@@ -82,7 +82,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LayerPatcher.apply_model_patches(
                 model=text_encoder,
                 patches=_lora_loader(),
                 prefix="lora_te_",
@@ -162,11 +162,11 @@ def run_clip_compel(
             c_pooled = None
             return c, c_pooled

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in clip_field.loras:
                 lora_info = context.models.load(lora.lora)
                 lora_model = lora_info.model
-                assert isinstance(lora_model, LoRAModelRaw)
+                assert isinstance(lora_model, ModelPatchRaw)
                 yield (lora_model, lora.weight)
                 del lora_info
             return
@@ -179,7 +179,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LayerPatcher.apply_model_patches(
                 text_encoder,
                 patches=_lora_loader(),
                 prefix=lora_prefix,
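
The compel.py hunks above are representative of the whole PR: each caller swaps LoRAModelRaw for ModelPatchRaw and LoRAPatcher.apply_lora_patches for LayerPatcher.apply_model_patches, while the lazy loader-generator pattern is unchanged. Here is a condensed sketch of that calling pattern; the cached_weights keyword is an assumption inferred from the truncated hunks and the surrounding context manager.

# Condensed sketch of the patching pattern repeated across these invocations.
from typing import Iterator, Tuple

from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher


def encode_with_loras(context, clip_field, text_encoder_info, tokenizer_info):
    def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
        # Yield patches lazily so each one is held only while being applied.
        for lora in clip_field.loras:
            lora_info = context.models.load(lora.lora)
            assert isinstance(lora_info.model, ModelPatchRaw)
            yield (lora_info.model, lora.weight)
            del lora_info  # drop the cache handle promptly

    with (
        # Apply all patches while the model is on the target device.
        text_encoder_info.model_on_device() as (cached_weights, text_encoder),
        tokenizer_info as tokenizer,
        LayerPatcher.apply_model_patches(
            model=text_encoder,
            patches=_lora_loader(),
            prefix="lora_te_",
            cached_weights=cached_weights,  # assumed: used to restore weights on exit
        ),
    ):
        ...  # run the text encoder (and tokenizer) with LoRA weights merged in
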
10 changes: 5 additions & 5 deletions invokeai/app/invocations/denoise_latents.py
@@ -37,10 +37,10 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@@ -987,10 +987,10 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
         def step_callback(state: PipelineIntermediateState) -> None:
             context.util.sd_step_callback(state, unet_config.base)

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.unet.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
@@ -1003,7 +1003,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            LoRAPatcher.apply_lora_patches(
+            LayerPatcher.apply_model_patches(
                 model=unet,
                 patches=_lora_loader(),
                 prefix="lora_unet_",
14 changes: 7 additions & 7 deletions invokeai/app/invocations/flux_denoise.py
@@ -47,10 +47,10 @@
     unpack,
 )
 from invokeai.backend.flux.text_conditioning import FluxTextConditioning
-from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
@@ -306,7 +306,7 @@ def _run_diffusion(
         if config.format in [ModelFormat.Checkpoint]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LayerPatcher.apply_model_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
                     prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -321,7 +321,7 @@
             # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference
             # than directly patching the weights, but is agnostic to the quantization format.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_sidecar_patches(
+                LayerPatcher.apply_model_sidecar_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
                     prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -715,15 +715,15 @@ def _prep_ip_adapter_extensions(

         return pos_ip_adapter_extensions, neg_ip_adapter_extensions

-    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
         loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
         if self.control_lora:
             # Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
             # applied last.
             loras.append(self.control_lora)
         for lora in loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, LoRAModelRaw)
+            assert isinstance(lora_info.model, ModelPatchRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
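
flux_denoise.py is the heart of the PR: a Checkpoint-format (non-quantized) transformer gets its weights patched in place, while a quantized transformer gets sidecar layers, since a dense LoRA delta cannot simply be added into a quantized weight tensor. Below is a minimal sketch of what a linear sidecar wrapper does at forward time, assuming LoRA-style patches like the sketch after the commit list; the PR's real LinearSidecarWrapper also has optimized paths for specific patch types (commit 1e0552c).

# Sketch of the sidecar idea: leave the (possibly quantized) weights untouched
# and add each LoRA's low-rank residual at forward time.
import torch


class LinearSidecarWrapper(torch.nn.Module):
    def __init__(self, orig_module: torch.nn.Module):
        super().__init__()
        self.orig_module = orig_module
        self.patches_and_weights: list[tuple[object, float]] = []

    def add_patch(self, patch, weight: float) -> None:
        self.patches_and_weights.append((patch, weight))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The base layer runs in its native (quantized) representation...
        y = self.orig_module(x)
        # ...and each patch adds scale * (x @ down.T) @ up.T on top.
        for patch, weight in self.patches_and_weights:
            scale = weight * patch.alpha / patch.down.shape[0]
            y = y + scale * torch.nn.functional.linear(
                torch.nn.functional.linear(x, patch.down), patch.up
            )
        return y

This also makes the trade-off stated in the diff comment concrete: two extra matmuls per patched layer on every forward pass, in exchange for never touching the quantized weights.
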
12 changes: 6 additions & 6 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -17,10 +17,10 @@
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.conditioner import HFEncoder
-from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo

@@ -111,7 +111,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LayerPatcher.apply_model_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context),
                     prefix=FLUX_LORA_CLIP_PREFIX,
@@ -130,9 +130,9 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
         return pooled_prompt_embeds

-    def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+    def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
         for lora in self.clip.loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, LoRAModelRaw)
+            assert isinstance(lora_info.model, ModelPatchRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
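
Both FLUX invocations gate on the model format in the same way; only the non-quantized branch is visible in the hunks above, and the sidecar branch of _clip_encode is truncated. Here is a sketch of the overall dispatch, assuming the quantized branch mirrors the one shown in flux_denoise.py (the exact set of formats handled is an assumption).

# Sketch of the format-based dispatch used by the FLUX invocations.
from contextlib import ExitStack

from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.model_patcher import LayerPatcher


def apply_patches_for_format(exit_stack: ExitStack, model, config, patches, prefix: str) -> None:
    if config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
        # Non-quantized: merge the LoRA deltas directly into the weights (fast path).
        exit_stack.enter_context(
            LayerPatcher.apply_model_patches(model=model, patches=patches, prefix=prefix)
        )
    else:
        # Quantized (e.g. bitsandbytes, GGUF): wrap target layers with sidecar
        # modules instead. Slower per step, but agnostic to the quantization format.
        exit_stack.enter_context(
            LayerPatcher.apply_model_sidecar_patches(model=model, patches=patches, prefix=prefix)
        )
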
12 changes: 6 additions & 6 deletions invokeai/app/invocations/sd3_text_encoder.py
@@ -16,10 +16,10 @@
 from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import SD3ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo

 # The SD3 T5 Max Sequence Length set based on the default in diffusers.
@@ -150,7 +150,7 @@ def _clip_encode(
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LayerPatcher.apply_model_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context, clip_model),
                     prefix=FLUX_LORA_CLIP_PREFIX,
@@ -193,9 +193,9 @@ def _clip_encode(

     def _clip_lora_iterator(
         self, context: InvocationContext, clip_model: CLIPField
-    ) -> Iterator[Tuple[LoRAModelRaw, float]]:
+    ) -> Iterator[Tuple[ModelPatchRaw, float]]:
         for lora in clip_model.loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, LoRAModelRaw)
+            assert isinstance(lora_info.model, ModelPatchRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
(next file; its name and change counts were not captured in this view)
@@ -22,8 +22,8 @@
 from invokeai.app.invocations.model import UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
 from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
     MultiDiffusionPipeline,
@@ -194,10 +194,10 @@ def step_callback(state: PipelineIntermediateState) -> None:
             context.util.sd_step_callback(state, unet_config.base)

         # Prepare an iterator that yields the UNet's LoRA models and their weights.
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.unet.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info

@@ -207,7 +207,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         with (
             ExitStack() as exit_stack,
             unet_info as unet,
-            LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+            LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
14 changes: 0 additions & 14 deletions invokeai/backend/lora/layers/any_lora_layer.py

This file was deleted.

29 changes: 0 additions & 29 deletions invokeai/backend/lora/layers/set_parameter_layer.py

This file was deleted.

(The remaining deletions in this diff, several of them empty placeholder files, are truncated in this view; their names were not captured.)
24 changes: 0 additions & 24 deletions invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py

This file was deleted.
