Commit b46d7ab

Partial Loading PR3: Integrate 1) partial loading, 2) quantized models, 3) model patching (#7500)
## Summary

This PR is the third in a sequence of PRs working towards support for partial loading of models onto the compute device (for low-VRAM operation). This PR updates the LoRA patching code so that the following features can cooperate fully:

- Partial loading of weights onto the GPU
- Quantized layers / weights
- Model patches (e.g. LoRA)

Note that this PR does not yet enable partial loading. It adds support in the model patching code so that partial loading can be enabled in a future PR.

## Technical Design Decisions

The layer patching logic has been integrated into the custom layers (via `CustomModuleMixin`) rather than keeping it in a separate set of wrapper layers, as before. This has the following advantages:

- It makes it easier to calculate the modified weights on the fly and then reuse the normal forward() logic.
- In the future, it makes it possible to pass original parameters that have already been cast to the device down to the LoRA calculation without re-casting (the current implementation hasn't fully taken advantage of this yet).

## Known Limitations

1. I haven't fully solved device management for patch types that require the original layer value to calculate the patch. These aren't very common, and they aren't compatible with some quantized layers, so I'm leaving this for the future if there's demand.
2. There is a small speed regression for models that have CPU bottlenecks. It appears to be caused by slightly slower method resolution on the custom layer sub-classes. The regression does not show up on larger models, like FLUX, that are almost entirely GPU-limited. I think this small regression is tolerable, but if we decide that it's not, the slowdown can easily be reclaimed by optimizing other CPU operations (e.g. if we only sent every 2nd progress image, we'd see a much more significant speedup).

## Related Issues / Discussions

- #7492
- #7494

## QA Instructions

Speed tests:

- Vanilla SD1 speed regression
  - Before: 3.156s (8.78 it/s)
  - After: 3.54s (8.35 it/s)
- Vanilla SDXL speed regression
  - Before: 6.23s (4.46 it/s)
  - After: 6.45s (4.31 it/s)
- Vanilla FLUX speed regression
  - Before: 12.02s (2.27 it/s)
  - After: 11.91s (2.29 it/s)

LoRA tests with default configuration:

- [x] SD1: A handful of LoRA variants
- [x] SDXL: A handful of LoRA variants
- [x] FLUX non-quantized: multiple LoRA variants
- [x] FLUX bnb-quantized: multiple LoRA variants
- [x] FLUX ggml-quantized: multiple LoRA variants
- [x] FLUX non-quantized: FLUX control LoRA
- [x] FLUX bnb-quantized: FLUX control LoRA
- [x] FLUX ggml-quantized: FLUX control LoRA

LoRA tests with sidecar patching forced:

- [x] SD1: A handful of LoRA variants
- [x] SDXL: A handful of LoRA variants
- [x] FLUX non-quantized: multiple LoRA variants
- [x] FLUX bnb-quantized: multiple LoRA variants
- [x] FLUX ggml-quantized: multiple LoRA variants
- [x] FLUX non-quantized: FLUX control LoRA
- [x] FLUX bnb-quantized: FLUX control LoRA
- [x] FLUX ggml-quantized: FLUX control LoRA

Other:

- [x] Smoke testing of IP-Adapter, ControlNet

All tests repeated on:

- [x] cuda
- [x] cpu (only SD1 tested, because larger models are prohibitively slow)
- [x] mps (FLUX tests skipped, because my Mac doesn't have enough memory to run them in a reasonable amount of time)

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
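
For reference, every call site in the diffs below switches from `LayerPatcher.apply_model_patches` / `LayerPatcher.apply_model_sidecar_patches` to the single `LayerPatcher.apply_smart_model_patches` entry point. The following is a minimal sketch of the new call shape using only the argument names visible in the diffs; the toy model and the empty patch iterator are illustrative stand-ins, not part of this PR:

```python
from contextlib import ExitStack

import torch

from invokeai.backend.patches.layer_patcher import LayerPatcher

# Toy stand-in; the real call sites pass a UNet, text encoder, or FLUX transformer.
model = torch.nn.Linear(8, 8)

with ExitStack() as exit_stack:
    exit_stack.enter_context(
        LayerPatcher.apply_smart_model_patches(
            model=model,
            patches=iter([]),  # real callers pass an Iterator[Tuple[ModelPatchRaw, float]]
            prefix="lora_unet_",
            dtype=next(model.parameters()).dtype,
            # Optional in the diffs below: cached_weights=..., and
            # force_sidecar_patching=True to force sidecar layers (used for
            # quantized FLUX transformers in flux_denoise.py).
        )
    )
    # Run inference here while the patches are applied.
```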
2 parents 6bf5b74 + 9a0a226 commit b46d7ab

File tree

50 files changed: +1732 / -1033 lines


invokeai/app/invocations/compel.py

Lines changed: 6 additions & 4 deletions

@@ -20,8 +20,8 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
 from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
-from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
@@ -82,10 +82,11 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LayerPatcher.apply_model_patches(
+            LayerPatcher.apply_smart_model_patches(
                 model=text_encoder,
                 patches=_lora_loader(),
                 prefix="lora_te_",
+                dtype=text_encoder.dtype,
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -179,10 +180,11 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LayerPatcher.apply_model_patches(
-                text_encoder,
+            LayerPatcher.apply_smart_model_patches(
+                model=text_encoder,
                 patches=_lora_loader(),
                 prefix=lora_prefix,
+                dtype=text_encoder.dtype,
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.

invokeai/app/invocations/denoise_latents.py

Lines changed: 3 additions & 2 deletions

@@ -39,8 +39,8 @@
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
-from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@@ -1003,10 +1003,11 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            LayerPatcher.apply_model_patches(
+            LayerPatcher.apply_smart_model_patches(
                 model=unet,
                 patches=_lora_loader(),
                 prefix="lora_unet_",
+                dtype=unet.dtype,
                 cached_weights=cached_weights,
             ),
         ):

invokeai/app/invocations/flux_denoise.py

Lines changed: 19 additions & 22 deletions

@@ -48,9 +48,9 @@
 )
 from invokeai.backend.flux.text_conditioning import FluxTextConditioning
 from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
-from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
@@ -304,36 +304,33 @@ def _run_diffusion(
         config = transformer_info.config
         assert config is not None
 
-        # Apply LoRA models to the transformer.
-        # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+        # Determine if the model is quantized.
+        # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
+        # slower inference than direct patching, but is agnostic to the quantization format.
         if config.format in [ModelFormat.Checkpoint]:
-            # The model is non-quantized, so we can apply the LoRA weights directly into the model.
-            exit_stack.enter_context(
-                LayerPatcher.apply_model_patches(
-                    model=transformer,
-                    patches=self._lora_iterator(context),
-                    prefix=FLUX_LORA_TRANSFORMER_PREFIX,
-                    cached_weights=cached_weights,
-                )
-            )
+            model_is_quantized = False
         elif config.format in [
             ModelFormat.BnbQuantizedLlmInt8b,
             ModelFormat.BnbQuantizednf4b,
             ModelFormat.GGUFQuantized,
         ]:
-            # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
-            # than directly patching the weights, but is agnostic to the quantization format.
-            exit_stack.enter_context(
-                LayerPatcher.apply_model_sidecar_patches(
-                    model=transformer,
-                    patches=self._lora_iterator(context),
-                    prefix=FLUX_LORA_TRANSFORMER_PREFIX,
-                    dtype=inference_dtype,
-                )
-            )
+            model_is_quantized = True
         else:
             raise ValueError(f"Unsupported model format: {config.format}")
 
+        # Apply LoRA models to the transformer.
+        # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+        exit_stack.enter_context(
+            LayerPatcher.apply_smart_model_patches(
+                model=transformer,
+                patches=self._lora_iterator(context),
+                prefix=FLUX_LORA_TRANSFORMER_PREFIX,
+                dtype=inference_dtype,
+                cached_weights=cached_weights,
+                force_sidecar_patching=model_is_quantized,
+            )
+        )
+
         # Prepare IP-Adapter extensions.
         pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
             pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 3 additions & 2 deletions

@@ -18,9 +18,9 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.conditioner import HFEncoder
 from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
-from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
 
 
@@ -111,10 +111,11 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LayerPatcher.apply_model_patches(
+                LayerPatcher.apply_smart_model_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context),
                     prefix=FLUX_LORA_CLIP_PREFIX,
+                    dtype=clip_text_encoder.dtype,
                     cached_weights=cached_weights,
                 )
             )

invokeai/app/invocations/sd3_text_encoder.py

Lines changed: 3 additions & 2 deletions

@@ -17,9 +17,9 @@
 from invokeai.app.invocations.primitives import SD3ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
-from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
 
 # The SD3 T5 Max Sequence Length set based on the default in diffusers.
@@ -150,10 +150,11 @@ def _clip_encode(
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LayerPatcher.apply_model_patches(
+                LayerPatcher.apply_smart_model_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context, clip_model),
                     prefix=FLUX_LORA_CLIP_PREFIX,
+                    dtype=clip_text_encoder.dtype,
                     cached_weights=cached_weights,
                 )
            )

invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py

Lines changed: 4 additions & 2 deletions

@@ -22,8 +22,8 @@
 from invokeai.app.invocations.model import UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
-from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
 from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
     MultiDiffusionPipeline,
@@ -207,7 +207,9 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
         with (
             ExitStack() as exit_stack,
             unet_info as unet,
-            LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+            LayerPatcher.apply_smart_model_patches(
+                model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
+            ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)

invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py

Lines changed: 13 additions & 10 deletions

@@ -1,9 +1,7 @@
 import torch
 
-from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
-    AUTOCAST_MODULE_TYPE_MAPPING,
-    apply_custom_layers_to_model,
-    remove_custom_layers_from_model,
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
 )
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
 from invokeai.backend.util.logging import InvokeAILogger
@@ -45,10 +43,10 @@ def __init__(self, model: torch.nn.Module, compute_device: torch.device):
 
     def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
         """Find all modules that support autocasting."""
-        return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING}
+        return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)}  # type: ignore
 
     def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
-        keys_in_modules_that_do_not_support_autocast = set()
+        keys_in_modules_that_do_not_support_autocast: set[str] = set()
         for key in self._cpu_state_dict.keys():
             for module_name in self._modules_that_support_autocast.keys():
                 if key.startswith(module_name):
@@ -70,6 +68,11 @@ def _move_non_persistent_buffers_to_device(self, device: torch.device):
                 if name in module._non_persistent_buffers_set:
                     module._buffers[name] = buffer.to(device, copy=True)
 
+    def _set_autocast_enabled_in_all_modules(self, enabled: bool):
+        """Set autocast_enabled flag in all modules that support device autocasting."""
+        for module in self._modules_that_support_autocast.values():
+            module.set_device_autocasting_enabled(enabled)
+
     @property
     def model(self) -> torch.nn.Module:
         return self._model
@@ -114,7 +117,7 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
 
         cur_state_dict = self._model.state_dict()
 
-        # First, process the keys *must* be loaded into VRAM.
+        # First, process the keys that *must* be loaded into VRAM.
        for key in self._keys_in_modules_that_do_not_support_autocast:
            param = cur_state_dict[key]
            if param.device.type == self._compute_device.type:
@@ -157,10 +160,10 @@
         self._cur_vram_bytes += vram_bytes_loaded
 
         if fully_loaded:
-            remove_custom_layers_from_model(self._model)
+            self._set_autocast_enabled_in_all_modules(False)
             # TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
         else:
-            apply_custom_layers_to_model(self._model)
+            self._set_autocast_enabled_in_all_modules(True)
 
         # Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in
         # the vram_bytes_loaded tracking.
@@ -197,5 +200,5 @@ def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
 
         # We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom
         # layers.
-        apply_custom_layers_to_model(self._model)
+        self._set_autocast_enabled_in_all_modules(True)
         return vram_bytes_freed

invokeai/backend/model_manager/load/model_cache/model_cache.py

Lines changed: 7 additions & 0 deletions

@@ -13,6 +13,9 @@
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
 from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
 from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
+    apply_custom_layers_to_model,
+)
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
@@ -143,6 +146,10 @@ def put(
         size = calc_model_size_by_data(self._logger, model)
         self.make_room(size)
 
+        # Inject custom modules into the model.
+        if isinstance(model, torch.nn.Module):
+            apply_custom_layers_to_model(model)
+
         running_on_cpu = self._execution_device == torch.device("cpu")
         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
         cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

(New file in the custom_modules directory; filename not shown in this view)

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+
+This directory contains custom implementations of common torch.nn.Module classes that add support for:
+- Streaming weights to the execution device
+- Applying sidecar patches at execution time (e.g. sidecar LoRA layers)
+
+Each custom class sub-classes the original module type that it is replacing, so the following properties are preserved:
+- `isinstance(m, torch.nn.OriginalModule)` should still work.
+- Patching the weights directly (e.g. for LoRA) should still work. (Of course, this is not possible for quantized layers, hence the sidecar support.)
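
For orientation, here is a hedged sketch of the `CustomModuleMixin` contract implied by the call sites in this commit (`set_device_autocasting_enabled` in cached_model_with_partial_load.py above, and `_patches_and_weights` / `_aggregate_patch_parameters` in `CustomConv1d` below). The class name, defaults, and signatures here are assumptions, not the actual implementation in custom_module_mixin.py:

```python
import torch


class CustomModuleMixinSketch:
    """Hedged sketch of the mixin interface inferred from this diff."""

    def __init__(self) -> None:
        # Whether forward() should stream weights to the input's device.
        self._device_autocasting_enabled = False
        # Registered (patch, weight) pairs, e.g. sidecar LoRA layers.
        self._patches_and_weights: list[tuple[object, float]] = []

    def set_device_autocasting_enabled(self, enabled: bool) -> None:
        # Toggled by CachedModelWithPartialLoad when the model becomes fully
        # loaded (False) or partially loaded (True).
        self._device_autocasting_enabled = enabled

    def _aggregate_patch_parameters(
        self,
        patches_and_weights: list[tuple[object, float]],
        orig_params: dict[str, torch.Tensor],
        device: torch.device,
    ) -> dict[str, torch.Tensor]:
        # Combine the registered patches into per-parameter residual tensors
        # that the custom forward() adds to the device-cast weights.
        raise NotImplementedError
```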

(New file defining CustomConv1d; filename not shown in this view)

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
+    add_nullable_tensors,
+)
+
+
+class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
+    def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+
+        # Prepare the original parameters for the patch aggregation.
+        orig_params = {"weight": weight, "bias": bias}
+        # Filter out None values.
+        orig_params = {k: v for k, v in orig_params.items() if v is not None}
+
+        aggregated_param_residuals = self._aggregate_patch_parameters(
+            patches_and_weights=self._patches_and_weights,
+            orig_params=orig_params,
+            device=input.device,
+        )
+
+        weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
+        bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
+        return self._conv_forward(input, weight, bias)
+
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+        return self._conv_forward(input, weight, bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if len(self._patches_and_weights) > 0:
+            return self._autocast_forward_with_patches(input)
+        elif self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)
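
A hedged usage sketch of these custom modules, assuming `apply_custom_layers_to_model` (imported in model_cache.py above) swaps supported layers for their Custom* counterparts in place; the toy `Sequential` model is illustrative only:

```python
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)

# Toy model; apply_custom_layers_to_model is assumed to replace supported
# layers (e.g. torch.nn.Conv1d) with their Custom* counterparts in place.
model = torch.nn.Sequential(torch.nn.Conv1d(4, 8, kernel_size=3), torch.nn.ReLU())
apply_custom_layers_to_model(model)

# Property from the new README above: the swapped-in layer is still an instance
# of the original module type, so existing isinstance checks keep working.
assert isinstance(model[0], torch.nn.Conv1d)

# With no patches registered and device autocasting explicitly disabled,
# forward() falls through to the stock torch.nn.Conv1d implementation
# (the final branch of CustomConv1d.forward above).
model[0].set_device_autocasting_enabled(False)
out = model(torch.randn(1, 4, 16))  # expected shape: (1, 8, 14)
```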
