
Commit 87fdcb7

Partial Loading PR4: Enable partial loading (behind config flag) (#7505)
## Summary

This PR adds support for partial loading of models onto the GPU. This enables models to run with much lower peak VRAM requirements (e.g. full FLUX dev with 8GB of VRAM).

The partial loading feature is enabled behind a new config flag: `enable_partial_loading=True`. This flag defaults to `False`.

**Note about performance:** The `ram` and `vram` config limits are still applied when `enable_partial_loading=True` is set. This can result in significant slowdowns compared to the 'old' behaviour. Consider the case where the VRAM limit is set to `vram=0.75` (GB) and we are trying to run an 8GB model. When `enable_partial_loading=False`, we attempt to load the entire model into VRAM, and if it fits (no OOM error) then it will run at full speed. When `enable_partial_loading=True`, since we have the option to partially load the model we will only load 0.75 GB into VRAM and leave the remaining 7.25 GB in RAM. This will cause inference to be much slower than before. To work around this, it is important that your `ram` and `vram` configs are carefully tuned. In a future PR, we will add the ability to dynamically set the RAM/VRAM limits based on the available memory / VRAM.

## Related Issues / Discussions

- #7492
- #7494
- #7500

## QA Instructions

Tests with `enable_partial_loading=True`, `vram=2`, on CUDA device. For all tests, we expect model memory to stay below 2 GB. Peak working memory will be higher.

- [x] SD1 inference
- [x] SDXL inference
- [x] FLUX non-quantized inference
- [x] FLUX GGML-quantized inference
- [x] FLUX BnB quantized inference
- [x] Variety of ControlNet / IP-Adapter / LoRA smoke tests

Tests with `enable_partial_loading=True`, and a hack to force all models to load 10%, on CUDA device:

- [x] SD1 inference
- [x] SDXL inference
- [x] FLUX non-quantized inference
- [x] FLUX GGML-quantized inference
- [x] FLUX BnB quantized inference
- [x] Variety of ControlNet / IP-Adapter / LoRA smoke tests

Tests with `enable_partial_loading=False`, `vram=30`. We expect no change in behaviour when `enable_partial_loading=False`.

- [x] SD1 inference
- [x] SDXL inference
- [x] FLUX non-quantized inference
- [x] FLUX GGML-quantized inference
- [x] FLUX BnB quantized inference
- [x] Variety of ControlNet / IP-Adapter / LoRA smoke tests

Other platforms:

- [x] No change in behavior on MPS, even if `enable_partial_loading=True`.
- [x] No change in behavior on CPU-only systems, even if `enable_partial_loading=True`.

## Merge Plan

- [x] Merge #7500 first, and change the target branch to main

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
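
For readers who want to try this, a minimal sketch of turning the flag on programmatically is below. It only uses field names added in this PR (`ram`, `vram`, `enable_partial_loading` on `InvokeAIAppConfig`); the numeric values are illustrative, and in a normal install these settings would live in your config file rather than be constructed in code.

```python
# Hedged sketch: enabling partial loading via the new config field.
# The limits below are examples only; tune `ram`/`vram` for your system.
from invokeai.app.services.config.config_default import InvokeAIAppConfig

config = InvokeAIAppConfig(
    ram=12.0,                     # GB of system RAM reserved for the model cache
    vram=2.0,                     # GB of VRAM reserved for model storage
    enable_partial_loading=True,  # stream weights from RAM to VRAM as they are used
)

assert config.enable_partial_loading  # defaults to False when not set
```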
2 parents 782ee7a + 6a9de1f commit 87fdcb7

File tree

23 files changed (+396, −292 lines)

invokeai/app/invocations/compel.py

Lines changed: 2 additions & 0 deletions
@@ -105,6 +105,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
         textual_inversion_manager=ti_manager,
         dtype_for_device_getter=TorchDevice.choose_torch_dtype,
         truncate_long_prompts=False,
+        device=TorchDevice.choose_torch_device(),
     )

     conjunction = Compel.parse_prompt_string(self.prompt)
@@ -207,6 +208,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
         truncate_long_prompts=False, # TODO:
         returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
         requires_pooled=get_pooled,
+        device=TorchDevice.choose_torch_device(),
     )

     conjunction = Compel.parse_prompt_string(prompt)

invokeai/app/invocations/latents_to_image.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
     with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
         context.util.signal_progress("Running VAE decoder")
         assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
-        latents = latents.to(vae.device)
+        latents = latents.to(TorchDevice.choose_torch_device())
         if self.fp32:
             vae.to(dtype=torch.float32)

invokeai/app/invocations/sd3_latents_to_image.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
     with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
         context.util.signal_progress("Running VAE")
         assert isinstance(vae, (AutoencoderKL))
-        latents = latents.to(vae.device)
+        latents = latents.to(TorchDevice.choose_torch_device())

         vae.disable_tiling()

invokeai/app/invocations/sd3_text_encoder.py

Lines changed: 3 additions & 2 deletions
@@ -21,6 +21,7 @@
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
+from invokeai.backend.util.devices import TorchDevice

 # The SD3 T5 Max Sequence Length set based on the default in diffusers.
 SD3_T5_MAX_SEQ_LEN = 256
@@ -120,7 +121,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tens
             f" {max_seq_len} tokens: {removed_text}"
         )

-        prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
+        prompt_embeds = t5_text_encoder(text_input_ids.to(TorchDevice.choose_torch_device()))[0]

         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
@@ -185,7 +186,7 @@ def _clip_encode(
             f" {tokenizer_max_length} tokens: {removed_text}"
         )
         prompt_embeds = clip_text_encoder(
-            input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
+            input_ids=text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
         )
         pooled_prompt_embeds = prompt_embeds[0]
         prompt_embeds = prompt_embeds.hidden_states[-2]
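
The changes in this file (and the analogous ones in the VAE-decode invocations and the FLUX conditioner in this commit) all replace `module.device` with `TorchDevice.choose_torch_device()` when moving inputs. The snippet below is not InvokeAI code; it is a small illustration, using a stand-in `nn.Sequential`, of why a single `.device` attribute stops being meaningful once a model can be only partially resident on the GPU.

```python
# Illustration only (stand-in module, not InvokeAI code): once weights can live
# on more than one device, "the model's device" is ambiguous.
import torch
from torch import nn

encoder = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
if torch.cuda.is_available():
    encoder[0].to("cuda")  # simulate a partial load: only the first layer is on the GPU

# `.device` on transformers models (like `next(model.parameters()).device`) typically
# reports the device of the first parameter, which may not describe the rest of the model.
print(next(encoder.parameters()).device)

# Under partial loading, callers therefore move inputs to the chosen execution
# device explicitly instead of asking the (possibly split) module where it lives.
execution_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokens = torch.zeros(1, 8).to(execution_device)
```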

invokeai/app/services/config/config_default.py

Lines changed: 2 additions & 0 deletions
@@ -107,6 +107,7 @@ class InvokeAIAppConfig(BaseSettings):
         vram: Amount of VRAM reserved for model storage (GB).
         lazy_offload: Keep models in VRAM until their space is needed.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
+        enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. Partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. If enabling this setting, make sure that your ram and vram cache limits are properly tuned.
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
         precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
@@ -178,6 +179,7 @@ class InvokeAIAppConfig(BaseSettings):
     vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
     lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
     log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
+    enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. Partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. If enabling this setting, make sure that your ram and vram cache limits are properly tuned.")

     # DEVICE
     device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")

invokeai/app/services/model_manager/model_manager_default.py

Lines changed: 3 additions & 3 deletions
@@ -82,9 +82,9 @@ def build_model_manager(
     logger.setLevel(app_config.log_level.upper())

     ram_cache = ModelCache(
-        max_cache_size=app_config.ram,
-        max_vram_cache_size=app_config.vram,
-        lazy_offloading=app_config.lazy_offload,
+        max_ram_cache_size_gb=app_config.ram,
+        max_vram_cache_size_gb=app_config.vram,
+        enable_partial_loading=app_config.enable_partial_loading,
         logger=logger,
         execution_device=execution_device or TorchDevice.choose_torch_device(),
     )

invokeai/backend/flux/modules/conditioner.py

Lines changed: 3 additions & 1 deletion
@@ -3,6 +3,8 @@
 from torch import Tensor, nn
 from transformers import PreTrainedModel, PreTrainedTokenizer

+from invokeai.backend.util.devices import TorchDevice
+

 class HFEncoder(nn.Module):
     def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
@@ -26,7 +28,7 @@ def forward(self, text: list[str]) -> Tensor:
         )

         outputs = self.hf_module(
-            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+            input_ids=batch_encoding["input_ids"].to(TorchDevice.choose_torch_device()),
             attention_mask=None,
             output_hidden_states=False,
         )

invokeai/backend/image_util/hed.py

Lines changed: 3 additions & 2 deletions
@@ -18,6 +18,7 @@
     resize_image_to_resolution,
     safe_step,
 )
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


 class DoubleConvBlock(torch.nn.Module):
@@ -109,7 +110,7 @@ def run(
         Returns:
             The detected edges.
         """
-        device = next(iter(self.network.parameters())).device
+        device = get_effective_device(self.network)
         np_image = pil_to_np(input_image)
         np_image = normalize_image_channel_count(np_image)
         np_image = resize_image_to_resolution(np_image, detect_resolution)
@@ -183,7 +184,7 @@ def run(self, image: Image.Image, safe: bool = False, scribble: bool = False) ->
            The detected edges.
        """

-        device = next(iter(self.model.parameters())).device
+        device = get_effective_device(self.model)

        np_image = pil_to_np(image)
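
This file and several of the image-processing utilities below switch from `next(iter(model.parameters())).device` to a new `get_effective_device()` helper. The helper lives in `invokeai/backend/model_manager/load/model_cache/utils.py`, whose body is not part of this excerpt; the following is only a plausible sketch of the behaviour its call sites rely on (prefer whichever non-CPU device holds any of the weights, otherwise fall back to CPU).

```python
# Hypothetical sketch, NOT the actual InvokeAI implementation: an "effective
# device" helper for modules whose weights may be split between CPU and GPU.
import itertools

import torch


def get_effective_device_sketch(model: torch.nn.Module) -> torch.device:
    """Return the first non-CPU device found among parameters/buffers, else CPU."""
    for tensor in itertools.chain(model.parameters(), model.buffers()):
        if tensor.device.type != "cpu":
            return tensor.device
    return torch.device("cpu")
```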

invokeai/backend/image_util/infill_methods/lama.py

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@

 import invokeai.backend.util.logging as logger
 from invokeai.backend.model_manager.config import AnyModel
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


 def norm_img(np_img):
@@ -31,7 +32,7 @@ def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
         mask = norm_img(mask)
         mask = (mask > 0) * 1

-        device = next(self._model.buffers()).device
+        device = get_effective_device(self._model)
         image = torch.from_numpy(image).unsqueeze(0).to(device)
         mask = torch.from_numpy(mask).unsqueeze(0).to(device)

invokeai/backend/image_util/lineart.py

Lines changed: 3 additions & 2 deletions
@@ -17,6 +17,7 @@
     pil_to_np,
     resize_image_to_resolution,
 )
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


 class ResidualBlock(nn.Module):
@@ -130,7 +131,7 @@ def run(
         Returns:
             The detected lineart.
         """
-        device = next(iter(self.model.parameters())).device
+        device = get_effective_device(self.model)

         np_image = pil_to_np(input_image)
         np_image = normalize_image_channel_count(np_image)
@@ -201,7 +202,7 @@ def run(self, image: Image.Image) -> Image.Image:
         Returns:
             The detected edges.
         """
-        device = next(iter(self.model.parameters())).device
+        device = get_effective_device(self.model)

         np_image = pil_to_np(image)

invokeai/backend/image_util/lineart_anime.py

Lines changed: 3 additions & 2 deletions
@@ -19,6 +19,7 @@
     pil_to_np,
     resize_image_to_resolution,
 )
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


 class UnetGenerator(nn.Module):
@@ -171,7 +172,7 @@ def run(self, input_image: Image.Image, detect_resolution: int = 512, image_reso
         Returns:
             The detected lineart.
         """
-        device = next(iter(self.model.parameters())).device
+        device = get_effective_device(self.model)
         np_image = pil_to_np(input_image)

         np_image = normalize_image_channel_count(np_image)
@@ -239,7 +240,7 @@ def to(self, device: torch.device):

     def run(self, image: Image.Image) -> Image.Image:
         """Processes an image and returns the detected edges."""
-        device = next(iter(self.model.parameters())).device
+        device = get_effective_device(self.model)

         np_image = pil_to_np(image)

invokeai/backend/image_util/mlsd/utils.py

Lines changed: 4 additions & 2 deletions
@@ -14,6 +14,8 @@
 import torch
 from torch.nn import functional as F

+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
+

 def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
     '''
@@ -49,7 +51,7 @@ def pred_lines(image, model,
                dist_thr=20.0):
     h, w, _ = image.shape

-    device = next(iter(model.parameters())).device
+    device = get_effective_device(model)
     h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]

     resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
@@ -108,7 +110,7 @@ def pred_squares(image,
     '''
     h, w, _ = image.shape
     original_shape = [h, w]
-    device = next(iter(model.parameters())).device
+    device = get_effective_device(model)

     resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
                                     np.ones([input_shape[0], input_shape[1], 1])], axis=-1)

invokeai/backend/image_util/normal_bae/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@

 from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


 class NormalMapDetector:
@@ -64,7 +65,7 @@ def to(self, device: torch.device):
     def run(self, image: Image.Image):
         """Processes an image and returns the detected normal map."""

-        device = next(iter(self.model.parameters())).device
+        device = get_effective_device(self.model)
         np_image = pil_to_np(image)

         height, width, _channels = np_image.shape

invokeai/backend/image_util/pidi/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -11,6 +11,7 @@

 from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
 from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


 class PIDINetDetector:
@@ -45,7 +46,7 @@ def run(
     ) -> Image.Image:
         """Processes an image and returns the detected edges."""

-        device = next(iter(self.model.parameters())).device
+        device = get_effective_device(self.model)

         np_img = pil_to_np(image)
         np_img = normalize_image_channel_count(np_img)

invokeai/backend/model_manager/load/load_base.py

Lines changed: 2 additions & 2 deletions
@@ -68,14 +68,14 @@ def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]],
         """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
         self._cache.lock(self._cache_record)
         try:
-            yield (self._cache_record.state_dict, self._cache_record.model)
+            yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
         finally:
             self._cache.unlock(self._cache_record)

     @property
     def model(self) -> AnyModel:
         """Return the model without locking it."""
-        return self._cache_record.model
+        return self._cache_record.cached_model.model


 class LoadedModel(LoadedModelWithoutConfig):

Lines changed: 11 additions & 28 deletions
@@ -1,38 +1,21 @@
 from dataclasses import dataclass
-from typing import Any, Dict, Optional

-import torch
+from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
+    CachedModelOnlyFullLoad,
+)
+from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
+    CachedModelWithPartialLoad,
+)


 @dataclass
 class CacheRecord:
-    """
-    Elements of the cache:
-
-    key: Unique key for each model, same as used in the models database.
-    model: Model in memory.
-    state_dict: A read-only copy of the model's state dict in RAM. It will be
-        used as a template for creating a copy in the VRAM.
-    size: Size of the model
-    loaded: True if the model's state dict is currently in VRAM
-
-    Before a model is executed, the state_dict template is copied into VRAM,
-    and then injected into the model. When the model is finished, the VRAM
-    copy of the state dict is deleted, and the RAM version is reinjected
-    into the model.
-
-    The state_dict should be treated as a read-only attribute. Do not attempt
-    to patch or otherwise modify it. Instead, patch the copy of the state_dict
-    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
-    context manager call `model_on_device()`.
-    """
+    """A class that represents a model in the model cache."""

+    # Cache key.
     key: str
-    model: Any
-    device: torch.device
-    state_dict: Optional[Dict[str, torch.Tensor]]
-    size: int
-    loaded: bool = False
+    # Model in memory.
+    cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
     _locks: int = 0

     def lock(self) -> None:
@@ -45,6 +28,6 @@ def unlock(self) -> None:
         assert self._locks >= 0

     @property
-    def locked(self) -> bool:
+    def is_locked(self) -> bool:
         """Return true if record is locked."""
         return self._locks > 0
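
The reworked `CacheRecord` above keeps only a key, a wrapped `cached_model`, and a lock count. Below is a hedged sketch of the locking contract, based solely on the dataclass shown in this hunk; `wrap_model()` is a hypothetical stand-in for building a `CachedModelWithPartialLoad` or `CachedModelOnlyFullLoad`, whose constructors are defined elsewhere in this PR.

```python
# Hedged usage sketch based only on the CacheRecord shown above.
# `wrap_model()` is a hypothetical placeholder for constructing one of the
# cached-model wrappers (their real constructors are not in this excerpt).
record = CacheRecord(key="sd-1/main/some-model", cached_model=wrap_model())

record.lock()
try:
    assert record.is_locked            # property renamed from `locked` in this commit
    model = record.cached_model.model  # the underlying torch module, without copying
finally:
    record.unlock()

assert not record.is_locked
```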
