Commit 0258b6a

Partial Loading PR5: Dynamic cache ram/vram limits (#7509)
## Summary

This PR enables RAM/VRAM cache size limits to be determined dynamically based on availability.

**Config Changes**

This PR modifies the app configs in the following ways:
- A new `device_working_mem_gb` config was added. This is the amount of non-model working memory to keep available on the execution device (i.e. GPU) when using dynamic cache limits. It defaults to 3 GB.
- The `ram` and `vram` configs now default to `None`. If these configs are set, they take precedence over the dynamic limits.

**Note: Some users may have previously overridden the `ram` and `vram` values in their `invokeai.yaml`. They will need to remove these configs to enable the new dynamic limit feature.**

**Working Memory**

In addition to the new `device_working_mem_gb` config described above, memory-intensive operations can estimate the amount of working memory that they will need and request it from the model cache. This is currently applied to the VAE decoding step for all models. In the future, we may apply this to other operations as we work out which ops tend to exceed the default working memory reservation. (A short illustrative sketch of this request pattern is included below, just after the commit details.)

**Mitigations for #7513**

This PR includes some mitigations for the issue described in #7513. Without these mitigations, the issue would occur more frequently when dynamic RAM limits are used and RAM is close to maxed out.

## Limitations / Future Work

- Only _models_ can be offloaded to RAM to conserve VRAM. I.e., if VAE decoding requires more working VRAM than is available, the best we can do is keep the full model on the CPU, but we will still hit an OOM error. In the future, we could detect this ahead of time and switch to running inference on the CPU for those ops.
- There is often a non-negligible amount of VRAM 'reserved' by the torch CUDA allocator but not used by any allocated tensors. We may be able to tune the torch CUDA allocator to work better for our use case. Reference: https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
- There may be some ops that require high working memory that haven't been updated to request extra memory yet. We will update these as we uncover them.
- If a model is 'locked' in VRAM, it won't be partially unloaded if a later model load requests extra working memory. This should be uncommon, but I can think of cases where it would matter.

## Related Issues / Discussions

- #7492
- #7494
- #7500
- #7505

## QA Instructions

Run a variety of models near the cache limits to ensure that model switching works properly for the following configurations:
- [x] CUDA, `enable_partial_loading=true`, all other configs default (i.e. dynamic memory limits)
- [x] CUDA, `enable_partial_loading=true`, CPU and CUDA memory reserved in another process so there is limited RAM/VRAM remaining, all other configs default (i.e. dynamic memory limits)
- [x] CUDA, `enable_partial_loading=false`, all other configs default (i.e. dynamic memory limits)
- [x] CUDA, ram/vram limits set (these should take precedence over the dynamic limits)
- [x] MPS, all other configs default (i.e. dynamic memory limits)
- [x] CPU, all other configs default (i.e. dynamic memory limits)

## Merge Plan

- [x] Merge #7505 first and change target branch to main

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2 parents 87fdcb7 + d7ab464 commit 0258b6a
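To make the working-memory mechanism from the summary concrete, here is a minimal sketch of the request pattern, assuming we are inside an invocation with access to a `context` and a VAE model field. The helper name and arguments are illustrative, not part of this commit; the real usage is the `flux_vae_decode.py` change further down.

```python
import torch


def decode_with_working_memory(context, vae_field, latents: torch.Tensor, working_mem_bytes: int):
    """Illustrative sketch of the working-memory request pattern introduced in this PR."""
    vae_info = context.models.load(vae_field)  # vae_field: e.g. self.vae.vae inside an invocation
    # model_on_device() accepts a working_mem_bytes hint: the model cache keeps at least this much
    # memory free on the execution device while the model is locked, partially offloading other
    # cached models if necessary.
    with vae_info.model_on_device(working_mem_bytes=working_mem_bytes) as (_, vae):
        return vae.decode(latents)  # illustrative decode call; see flux_vae_decode.py below
```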

File tree

20 files changed: +314 −298 lines

invokeai/app/api/routers/model_manager.py

Lines changed: 0 additions & 70 deletions
@@ -4,7 +4,6 @@
 import contextlib
 import io
 import pathlib
-import shutil
 import traceback
 from copy import deepcopy
 from enum import Enum
@@ -21,7 +20,6 @@
 from typing_extensions import Annotated
 
 from invokeai.app.api.dependencies import ApiDependencies
-from invokeai.app.services.config import get_config
 from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
 from invokeai.app.services.model_install.model_install_common import ModelInstallJob
 from invokeai.app.services.model_records import (
@@ -848,74 +846,6 @@ async def get_starter_models() -> StarterModelResponse:
     return StarterModelResponse(starter_models=starter_models, starter_bundles=starter_bundles)
 
 
-@model_manager_router.get(
-    "/model_cache",
-    operation_id="get_cache_size",
-    response_model=float,
-    summary="Get maximum size of model manager RAM or VRAM cache.",
-)
-async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
-    """Return the current RAM or VRAM cache size setting (in GB)."""
-    cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
-    value = 0.0
-    if cache_type == CacheType.RAM:
-        value = cache.max_cache_size
-    elif cache_type == CacheType.VRAM:
-        value = cache.max_vram_cache_size
-    return value
-
-
-@model_manager_router.put(
-    "/model_cache",
-    operation_id="set_cache_size",
-    response_model=float,
-    summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
-)
-async def set_cache_size(
-    value: float = Query(description="The new value for the maximum cache size"),
-    cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
-    persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
-) -> float:
-    """Set the current RAM or VRAM cache size setting (in GB). ."""
-    cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
-    app_config = get_config()
-    # Record initial state.
-    vram_old = app_config.vram
-    ram_old = app_config.ram
-
-    # Prepare target state.
-    vram_new = vram_old
-    ram_new = ram_old
-    if cache_type == CacheType.RAM:
-        ram_new = value
-    elif cache_type == CacheType.VRAM:
-        vram_new = value
-    else:
-        raise ValueError(f"Unexpected {cache_type=}.")
-
-    config_path = app_config.config_file_path
-    new_config_path = config_path.with_suffix(".yaml.new")
-
-    try:
-        # Try to apply the target state.
-        cache.max_vram_cache_size = vram_new
-        cache.max_cache_size = ram_new
-        app_config.ram = ram_new
-        app_config.vram = vram_new
-        if persist:
-            app_config.write_file(new_config_path)
-            shutil.move(new_config_path, config_path)
-    except Exception as e:
-        # If there was a failure, restore the initial state.
-        cache.max_cache_size = ram_old
-        cache.max_vram_cache_size = vram_old
-        app_config.ram = ram_old
-        app_config.vram = vram_old
-
-        raise RuntimeError("Failed to update cache size") from e
-    return value
-
-
 @model_manager_router.get(
     "/stats",
     operation_id="get_stats",

invokeai/app/invocations/compel.py

Lines changed: 3 additions & 8 deletions
@@ -63,9 +63,6 @@ class CompelInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.models.load(self.clip.tokenizer)
-        text_encoder_info = context.models.load(self.clip.text_encoder)
-
         def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.clip.loras:
                 lora_info = context.models.load(lora.lora)
@@ -76,12 +73,13 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
 
         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
+        text_encoder_info = context.models.load(self.clip.text_encoder)
        ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
 
         with (
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
-            tokenizer_info as tokenizer,
+            context.models.load(self.clip.tokenizer) as tokenizer,
             LayerPatcher.apply_smart_model_patches(
                 model=text_encoder,
                 patches=_lora_loader(),
@@ -140,9 +138,7 @@ def run_clip_compel(
     lora_prefix: str,
     zero_on_empty: bool,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-    tokenizer_info = context.models.load(clip_field.tokenizer)
     text_encoder_info = context.models.load(clip_field.text_encoder)
-
     # return zero on empty
     if prompt == "" and zero_on_empty:
         cpu_text_encoder = text_encoder_info.model
@@ -180,7 +176,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
     with (
         # apply all patches while the model is on the target device
         text_encoder_info.model_on_device() as (cached_weights, text_encoder),
-        tokenizer_info as tokenizer,
+        context.models.load(clip_field.tokenizer) as tokenizer,
         LayerPatcher.apply_smart_model_patches(
             model=text_encoder,
             patches=_lora_loader(),
@@ -226,7 +222,6 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
 
     del tokenizer
     del text_encoder
-    del tokenizer_info
     del text_encoder_info
 
     c = c.detach().to("cpu")
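The recurring change in this and the following diffs is consistent with the #7513 mitigations mentioned in the summary: instead of acquiring a `LoadedModel` handle early and entering it later (a window in which the cache may drop the model, as the HACK comment in `flux_denoise.py` below describes), the `context.models.load(...)` call is moved directly into the `with` statement. A simplified schematic, with generic names and an injected `encode` callable standing in for the real work:

```python
from typing import Callable


def encode_old_pattern(context, clip_field, encode: Callable):
    # Old pattern: handles are acquired up front and entered later. Between load() and the
    # `with`, other model loads can evict these models from the cache, forcing a second
    # (slow) load when the handles are finally entered.
    tokenizer_info = context.models.load(clip_field.tokenizer)
    text_encoder_info = context.models.load(clip_field.text_encoder)
    with text_encoder_info as text_encoder, tokenizer_info as tokenizer:
        return encode(text_encoder, tokenizer)


def encode_new_pattern(context, clip_field, encode: Callable):
    # New pattern (this PR): load and enter in one step, keeping the window between loading
    # a model and locking it on the execution device as small as possible.
    with (
        context.models.load(clip_field.text_encoder) as text_encoder,
        context.models.load(clip_field.tokenizer) as tokenizer,
    ):
        return encode(text_encoder, tokenizer)
```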

invokeai/app/invocations/denoise_latents.py

Lines changed: 4 additions & 10 deletions
@@ -547,7 +547,6 @@ def prep_ip_adapter_image_prompts(
         for single_ip_adapter in ip_adapters:
             with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
                 assert isinstance(ip_adapter_model, IPAdapter)
-                image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
                 # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
                 single_ipa_image_fields = single_ip_adapter.image
                 if not isinstance(single_ipa_image_fields, list):
@@ -556,7 +555,7 @@
                 single_ipa_images = [
                     context.images.get_pil(image.image_name, mode="RGB") for image in single_ipa_image_fields
                 ]
-                with image_encoder_model_info as image_encoder_model:
+                with context.models.load(single_ip_adapter.image_encoder_model) as image_encoder_model:
                     assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
                     # Get image embeddings from CLIP and ImageProjModel.
                     image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
@@ -621,7 +620,6 @@ def run_t2i_adapters(
         t2i_adapter_data = []
         for t2i_adapter_field in t2i_adapter:
             t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
-            t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
             image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
 
             # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -637,7 +635,7 @@
                 raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
 
             t2i_adapter_model: T2IAdapter
-            with t2i_adapter_loaded_model as t2i_adapter_model:
+            with context.models.load(t2i_adapter_field.t2i_adapter_model) as t2i_adapter_model:
                 total_downscale_factor = t2i_adapter_model.total_downscale_factor
 
                 # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
@@ -926,10 +924,8 @@ def step_callback(state: PipelineIntermediateState) -> None:
         # ext: t2i/ip adapter
         ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
 
-        unet_info = context.models.load(self.unet.unet)
-        assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
-            unet_info.model_on_device() as (cached_weights, unet),
+            context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
             ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
             # ext: controlnet
             ext_manager.patch_extensions(denoise_ctx),
@@ -995,11 +991,9 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
                 del lora_info
             return
 
-        unet_info = context.models.load(self.unet.unet)
-        assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
            ExitStack() as exit_stack,
-            unet_info.model_on_device() as (cached_weights, unet),
+            context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.

invokeai/app/invocations/flux_denoise.py

Lines changed: 15 additions & 12 deletions
@@ -199,8 +199,8 @@ def _run_diffusion(
             else None
         )
 
-        transformer_info = context.models.load(self.transformer.transformer)
-        is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")
+        transformer_config = context.models.get_config(self.transformer.transformer)
+        is_schnell = "schnell" in getattr(transformer_config, "config_path", "")
 
         # Calculate the timestep schedule.
         timesteps = get_schedule(
@@ -299,9 +299,11 @@ def _run_diffusion(
             )
 
             # Load the transformer model.
-            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
+            (cached_weights, transformer) = exit_stack.enter_context(
+                context.models.load(self.transformer.transformer).model_on_device()
+            )
             assert isinstance(transformer, Flux)
-            config = transformer_info.config
+            config = transformer_config
             assert config is not None
 
             # Determine if the model is quantized.
@@ -512,15 +514,18 @@ def _prep_controlnet_extensions(
         # before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
         # minimize peak memory.
 
-        # First, load the ControlNet models so that we can determine the ControlNet types.
-        controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
-
         # Calculate the controlnet conditioning tensors.
         # We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
         # keep peak memory down.
         controlnet_conds: list[torch.Tensor] = []
-        for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
+        for controlnet in controlnets:
             image = context.images.get_pil(controlnet.image.image_name)
+
+            # HACK(ryand): We have to load the ControlNet model to determine whether the VAE needs to be run. We really
+            # shouldn't have to load the model here. There's a risk that the model will be dropped from the model cache
+            # before we load it into VRAM and thus we'll have to load it again (context:
+            # https://github.com/invoke-ai/InvokeAI/issues/7513).
+            controlnet_model = context.models.load(controlnet.control_model)
             if isinstance(controlnet_model.model, InstantXControlNetFlux):
                 if self.controlnet_vae is None:
                     raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
@@ -550,10 +555,8 @@ def _prep_controlnet_extensions(
 
         # Finally, load the ControlNet models and initialize the ControlNet extensions.
         controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
-        for controlnet, controlnet_cond, controlnet_model in zip(
-            controlnets, controlnet_conds, controlnet_models, strict=True
-        ):
-            model = exit_stack.enter_context(controlnet_model)
+        for controlnet, controlnet_cond in zip(controlnets, controlnet_conds, strict=True):
+            model = exit_stack.enter_context(context.models.load(controlnet.control_model))
 
             if isinstance(model, XLabsControlNetFlux):
                 controlnet_extensions.append(

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 7 additions & 12 deletions
@@ -69,14 +69,11 @@ def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
         )
 
     def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
-        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
-        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
-
         prompt = [self.prompt]
 
         with (
-            t5_text_encoder_info as t5_text_encoder,
-            t5_tokenizer_info as t5_tokenizer,
+            context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
+            context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
         ):
             assert isinstance(t5_text_encoder, T5EncoderModel)
             assert isinstance(t5_tokenizer, T5Tokenizer)
@@ -90,22 +87,20 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
         return prompt_embeds
 
     def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
-        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
-        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
-
         prompt = [self.prompt]
 
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+        clip_text_encoder_config = clip_text_encoder_info.config
+        assert clip_text_encoder_config is not None
+
         with (
             clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
-            clip_tokenizer_info as clip_tokenizer,
+            context.models.load(self.clip.tokenizer) as clip_tokenizer,
             ExitStack() as exit_stack,
         ):
             assert isinstance(clip_text_encoder, CLIPTextModel)
             assert isinstance(clip_tokenizer, CLIPTokenizer)
 
-            clip_text_encoder_config = clip_text_encoder_info.config
-            assert clip_text_encoder_config is not None
-
             # Apply LoRA models to the CLIP encoder.
             # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
             if clip_text_encoder_config.format in [ModelFormat.Diffusers]:

invokeai/app/invocations/flux_vae_decode.py

Lines changed: 18 additions & 2 deletions
@@ -3,6 +3,7 @@
 from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -24,7 +25,7 @@
     title="FLUX Latents to Image",
     tags=["latents", "image", "vae", "l2i", "flux"],
     category="latents",
-    version="1.0.0",
+    version="1.0.1",
 )
 class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Generates an image from latents."""
@@ -38,8 +39,23 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
     )
 
+    def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
+        # element size (precision).
+        out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
+        out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
+        element_size = next(vae.parameters()).element_size()
+        scaling_constant = 1090  # Determined experimentally.
+        working_memory = out_h * out_w * element_size * scaling_constant
+
+        # We add a 20% buffer to the working memory estimate to be safe.
+        working_memory = working_memory * 1.2
+        return int(working_memory)
+
     def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
-        with vae_info as vae:
+        estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
             vae_dtype = next(iter(vae.parameters())).dtype
             latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
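For a sense of the magnitude this estimate produces, here is a back-of-the-envelope calculation for a 1024×1024 output image. It assumes `LATENT_SCALE_FACTOR` is 8 and 16-bit VAE weights (element size of 2 bytes); both are assumptions for illustration rather than values shown in the diff above.

```python
# Back-of-the-envelope check of _estimate_working_memory for a 1024x1024 output image.
# Assumes LATENT_SCALE_FACTOR == 8 and 16-bit VAE weights (element_size == 2).
out_h = out_w = 8 * 128                  # 128x128 latent -> 1024x1024 pixels
element_size = 2                         # bytes per element for 16-bit weights
scaling_constant = 1090                  # experimentally determined constant from the code above
working_memory = out_h * out_w * element_size * scaling_constant * 1.2  # +20% safety buffer

print(f"{working_memory / 2**30:.2f} GiB")  # ~2.55 GiB requested as working memory
```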
