Commit 6bf5b74
Partial Loading PR2: Add utils to support partial loading of models from CPU to GPU (#7494)
## Summary

This PR adds utilities to support partial loading of models from CPU to GPU. The new utilities are not yet being used by the ModelCache, so there should be no functional behavior changes in this PR.

Detailed changes:

- Add autocast modules that are designed to wrap common `torch.nn.Module`s and enable them to run with automatic device casting. E.g. a linear layer on the CPU can be executed with an input tensor on the GPU by streaming the weights to the GPU at runtime.
- Add unit tests for the aforementioned autocast modules to verify that they work for all supported quantization formats (GGUF, BnB NF4, BnB LLM.int8()).
- Add `CachedModelWithPartialLoad` and `CachedModelOnlyFullLoad` classes to manage partial loading at the model level.

## Alternative Implementations

Several options were explored for supporting inference on partially-loaded models. The pros/cons of the explored options are summarized here for reference. In the end, wrapper modules were selected as the best overall solution for our use case.

Option 1: Re-implement the `.forward()` methods of modules to add support for device conversions

- This is the option implemented in this PR.
- This approach is the most manual of the three, but as a result offers the broadest compatibility with unusual model types. It is manual in that we have to explicitly add support for all module types that we wish to support. Fortunately, the list of foundational module types is relatively small (e.g. the current set of implemented layers covers all but 0.04 MB of the full FLUX model).

Option 2: Implement a custom Tensor type that casts tensors to a `target_device` each time the tensor is used

- This approach has the nice property that it is injected at the tensor level, and the model does not need to be modified in any way.
- One challenge with this approach is handling interactions with other custom tensor types (e.g. GGMLTensor). This problem is solvable, but it definitely introduces a layer of complexity. (There are likely to be similar issues with interactions with the BnB quantization, but I didn't get as far as testing BnB.)

Option 3: Override the `__torch_function__` dispatch calls globally and cast all params to the execution device (a rough illustrative sketch of this approach is included after the checklist below)

- This approach is nice and simple: just apply a global context manager and all operations will happen on the compute device regardless of the device of the participating tensors.
- Challenges:
  - Overriding the `__torch_function__` dispatch calls introduces some overhead even if the tensors are already on the correct device.
  - It is difficult to manage the autocasting context manager. E.g. it is tempting to apply it to the model's `.forward(...)` method, but we use some models with non-standard entrypoints. And we don't want to end up with nested autocasting context managers.
  - BnB applies quantization side effects when a param is moved to the GPU - this interacts in unexpected ways with a global context manager.

## QA Instructions

Most of the changes in this PR should not impact active code, and thus should not cause any changes to behavior. The main risks come from bumping the bitsandbytes dependency and some minor modifications to the bitsandbytes quantization code.

- [x] Regression test bitsandbytes NF4 quantization
- [x] Regression test bitsandbytes LLM.int8() quantization
- [x] Regression test on macOS (to ensure that there are no lingering bitsandbytes import errors)

I also tested the new utilities for inference on full models in another branch to validate that there were no major issues. This functionality will be tested more thoroughly in a future PR.

## Merge Plan

- [x] #7492 should be merged first so that the target branch can be updated to main.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
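For reference, the following is a minimal, illustrative sketch of the Option 3 approach described above. It is not part of this PR; the `CastToExecutionDevice` mode, its behavior, and the usage names are assumptions, not InvokeAI code.

```python
import torch
from torch.overrides import TorchFunctionMode


class CastToExecutionDevice(TorchFunctionMode):
    """Illustrative only: cast every tensor argument to the execution device before each op runs."""

    def __init__(self, device: torch.device):
        super().__init__()
        self._device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def to_device(x):
            return x.to(self._device) if isinstance(x, torch.Tensor) else x

        # Only flat tensor arguments are handled here; nested containers would need a pytree map.
        args = tuple(to_device(a) for a in args)
        kwargs = {k: to_device(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)


# Usage: everything inside the context runs on the compute device, regardless of
# where the participating tensors currently live (hypothetical `model` / `latents`):
# with CastToExecutionDevice(torch.device("cuda")):
#     output = model(latents)
```

Even in this sketch, the drawbacks listed above are visible: the mode intercepts every dispatch (overhead on already-resident tensors), and the context manager has to be scoped correctly around whatever entrypoint the model uses.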
2 parents d3916db + 0fc5387 commit 6bf5b74

File tree

16 files changed: +1302 -8 lines changed

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
from typing import Any

import torch


class CachedModelOnlyFullLoad:
    """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.

    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
    MPS memory, etc.
    """

    def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
        """Initialize a CachedModelOnlyFullLoad.

        Args:
            model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
            compute_device (torch.device): The compute device to move the model to.
            total_bytes (int): The total size (in bytes) of all the weights in the model.
        """
        # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
        self._model = model
        self._compute_device = compute_device
        self._offload_device = torch.device("cpu")

        # A CPU read-only copy of the model's state dict.
        self._cpu_state_dict: dict[str, torch.Tensor] | None = None
        if isinstance(model, torch.nn.Module):
            self._cpu_state_dict = model.state_dict()

        self._total_bytes = total_bytes
        self._is_in_vram = False

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
        """Get a read-only copy of the model's state dict in RAM."""
        # TODO(ryand): Document this better.
        return self._cpu_state_dict

    def total_bytes(self) -> int:
        """Get the total size (in bytes) of all the weights in the model."""
        return self._total_bytes

    def cur_vram_bytes(self) -> int:
        """Get the size (in bytes) of the weights that are currently in VRAM."""
        if self._is_in_vram:
            return self._total_bytes
        else:
            return 0

    def is_in_vram(self) -> bool:
        """Return true if the model is currently in VRAM."""
        return self._is_in_vram

    def full_load_to_vram(self) -> int:
        """Load all weights into VRAM (if supported by the model).

        Returns:
            The number of bytes loaded into VRAM.
        """
        if self._is_in_vram:
            # Already in VRAM.
            return 0

        if not hasattr(self._model, "to"):
            # Model doesn't support moving to a device.
            return 0

        if self._cpu_state_dict is not None:
            new_state_dict: dict[str, torch.Tensor] = {}
            for k, v in self._cpu_state_dict.items():
                new_state_dict[k] = v.to(self._compute_device, copy=True)
            self._model.load_state_dict(new_state_dict, assign=True)
        self._model.to(self._compute_device)

        self._is_in_vram = True
        return self._total_bytes

    def full_unload_from_vram(self) -> int:
        """Unload all weights from VRAM.

        Returns:
            The number of bytes unloaded from VRAM.
        """
        if not self._is_in_vram:
            # Already in RAM.
            return 0

        if self._cpu_state_dict is not None:
            self._model.load_state_dict(self._cpu_state_dict, assign=True)
        self._model.to(self._offload_device)

        self._is_in_vram = False
        return self._total_bytes
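
For context, a hypothetical usage sketch of `CachedModelOnlyFullLoad` (not part of the diff); the example model and the way `total_bytes` is computed here are assumptions for illustration:

```python
import torch

# Hypothetical example: any module (or other model type) can be wrapped; total_bytes is supplied by the caller.
model = torch.nn.Linear(1024, 1024)
total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

cached = CachedModelOnlyFullLoad(model, compute_device=torch.device("cuda"), total_bytes=total_bytes)

loaded_bytes = cached.full_load_to_vram()     # moves the whole model to the compute device, returns total_bytes
assert cached.is_in_vram()
freed_bytes = cached.full_unload_from_vram()  # restores the CPU state dict, returns total_bytes
assert cached.cur_vram_bytes() == 0
```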
Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    AUTOCAST_MODULE_TYPE_MAPPING,
    apply_custom_layers_to_model,
    remove_custom_layers_from_model,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from invokeai.backend.util.logging import InvokeAILogger


def set_nested_attr(obj: object, attr: str, value: object):
    """A helper function that extends setattr() to support nested attributes.

    Example:
        set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
    """
    attrs = attr.split(".")
    for attr in attrs[:-1]:
        obj = getattr(obj, attr)
    setattr(obj, attrs[-1], value)


class CachedModelWithPartialLoad:
    """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.

    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
    MPS memory, etc.
    """

    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
        self._model = model
        self._compute_device = compute_device

        # A CPU read-only copy of the model's state dict.
        self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()

        # TODO(ryand): Handle the case where the model size changes after initial load (e.g. due to dtype casting).
        # Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
        self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
        self._cur_vram_bytes: int | None = None

        self._modules_that_support_autocast = self._find_modules_that_support_autocast()
        self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()

    def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
        """Find all modules that support autocasting."""
        return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING}

    def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
        keys_in_modules_that_do_not_support_autocast = set()
        for key in self._cpu_state_dict.keys():
            for module_name in self._modules_that_support_autocast.keys():
                if key.startswith(module_name):
                    break
            else:
                keys_in_modules_that_do_not_support_autocast.add(key)
        return keys_in_modules_that_do_not_support_autocast

    def _move_non_persistent_buffers_to_device(self, device: torch.device):
        """Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
        so we need to move them manually.
        """
        # HACK(ryand): Typically, non-persistent buffers are moved when calling module.to(device). We don't move entire
        # modules, because we manage the devices of individual tensors using the state dict. Since non-persistent
        # buffers are not included in the state dict, we need to handle them manually. The only way to do this is by
        # using private torch.nn.Module attributes.
        for module in self._model.modules():
            for name, buffer in module.named_buffers():
                if name in module._non_persistent_buffers_set:
                    module._buffers[name] = buffer.to(device, copy=True)

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
        """Get a read-only copy of the model's state dict in RAM."""
        # TODO(ryand): Document this better.
        return self._cpu_state_dict

    def total_bytes(self) -> int:
        """Get the total size (in bytes) of all the weights in the model."""
        return self._total_bytes

    def cur_vram_bytes(self) -> int:
        """Get the size (in bytes) of the weights that are currently in VRAM."""
        if self._cur_vram_bytes is None:
            cur_state_dict = self._model.state_dict()
            self._cur_vram_bytes = sum(
                calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
            )
        return self._cur_vram_bytes

    def full_load_to_vram(self) -> int:
        """Load all weights into VRAM."""
        return self.partial_load_to_vram(self.total_bytes())

    def full_unload_from_vram(self) -> int:
        """Unload all weights from VRAM."""
        return self.partial_unload_from_vram(self.total_bytes())

    @torch.no_grad()
    def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
        """Load more weights into VRAM without exceeding vram_bytes_to_load.

        Returns:
            The number of bytes loaded into VRAM.
        """
        # TODO(ryand): Handle the case where an exception is thrown while loading or unloading weights. At the very
        # least, we should reset self._cur_vram_bytes to None.

        vram_bytes_loaded = 0

        cur_state_dict = self._model.state_dict()

        # First, process the keys that *must* be loaded into VRAM.
        for key in self._keys_in_modules_that_do_not_support_autocast:
            param = cur_state_dict[key]
            if param.device.type == self._compute_device.type:
                continue

            param_size = calc_tensor_size(param)
            cur_state_dict[key] = param.to(self._compute_device, copy=True)
            vram_bytes_loaded += param_size

        if vram_bytes_loaded > vram_bytes_to_load:
            logger = InvokeAILogger.get_logger()
            logger.warning(
                f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
                "requested. This is the minimum set of weights in VRAM required to run the model."
            )

        # Next, process the keys that can optionally be loaded into VRAM.
        fully_loaded = True
        for key, param in cur_state_dict.items():
            if param.device.type == self._compute_device.type:
                continue

            param_size = calc_tensor_size(param)
            if vram_bytes_loaded + param_size > vram_bytes_to_load:
                # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
                # worth continuing to search for a smaller parameter that would fit?
                fully_loaded = False
                continue

            cur_state_dict[key] = param.to(self._compute_device, copy=True)
            vram_bytes_loaded += param_size

        if vram_bytes_loaded > 0:
            # We load the entire state dict, not just the parameters that changed, in case there are modules that
            # override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
            # Alternatively, in the future, grouping parameters by module could probably solve this problem.
            self._model.load_state_dict(cur_state_dict, assign=True)

        if self._cur_vram_bytes is not None:
            self._cur_vram_bytes += vram_bytes_loaded

        if fully_loaded:
            remove_custom_layers_from_model(self._model)
            # TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
        else:
            apply_custom_layers_to_model(self._model)

        # Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in
        # the vram_bytes_loaded tracking.
        self._move_non_persistent_buffers_to_device(self._compute_device)

        return vram_bytes_loaded

    @torch.no_grad()
    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
        """Unload weights from VRAM until vram_bytes_to_free bytes are freed, or the entire model is unloaded.

        Returns:
            The number of bytes unloaded from VRAM.
        """
        vram_bytes_freed = 0

        offload_device = "cpu"
        cur_state_dict = self._model.state_dict()
        for key, param in cur_state_dict.items():
            if vram_bytes_freed >= vram_bytes_to_free:
                break

            if param.device.type == offload_device:
                continue

            cur_state_dict[key] = self._cpu_state_dict[key]
            vram_bytes_freed += calc_tensor_size(param)

        if vram_bytes_freed > 0:
            self._model.load_state_dict(cur_state_dict, assign=True)

        if self._cur_vram_bytes is not None:
            self._cur_vram_bytes -= vram_bytes_freed

        # We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom
        # layers.
        apply_custom_layers_to_model(self._model)
        return vram_bytes_freed
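
A hypothetical usage sketch of `CachedModelWithPartialLoad` (not part of the diff), showing how a caller such as the model cache might drive partial loading against a VRAM budget; the example model and the budget value are assumptions:

```python
import torch

# Hypothetical example: a plain stack of linear layers and an arbitrary 1 GB VRAM budget.
model = torch.nn.Sequential(*[torch.nn.Linear(4096, 4096) for _ in range(8)])
cached = CachedModelWithPartialLoad(model, compute_device=torch.device("cuda"))

budget_bytes = 1 * 2**30
loaded = cached.partial_load_to_vram(budget_bytes)
print(
    f"Loaded {loaded / 2**20:.1f} MB; {cached.cur_vram_bytes() / 2**20:.1f} of "
    f"{cached.total_bytes() / 2**20:.1f} MB currently in VRAM"
)

# Later, free roughly half of that again (e.g. to make room for another model in the cache).
freed = cached.partial_unload_from_vram(budget_bytes // 2)
```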

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/__init__.py

Whitespace-only changes.
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device

# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
# Each class sub-classes the original module type that it is replacing, so the following properties are preserved:
# - isinstance(m, torch.nn.OriginalModule) should still work.
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.


class CustomLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)


class CustomConv1d(torch.nn.Conv1d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomConv2d(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomGroupNorm(torch.nn.GroupNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


class CustomEmbedding(torch.nn.Embedding):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        return torch.nn.functional.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
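
The `torch_module_autocast.py` module that defines `AUTOCAST_MODULE_TYPE_MAPPING`, `apply_custom_layers_to_model`, and `remove_custom_layers_from_model` is not included in this excerpt. Below is a minimal sketch of how such a module could work, assuming the mapping covers the basic layer types above and that the swap is done by reassigning each module's class in place; this is an illustrative assumption, not necessarily the PR's actual implementation:

```python
import torch

# Assumed mapping from standard layer types to the custom autocast layers above.
# The real mapping presumably also covers the quantized layer types (GGUF, BnB).
AUTOCAST_MODULE_TYPE_MAPPING: dict[type, type] = {
    torch.nn.Linear: CustomLinear,
    torch.nn.Conv1d: CustomConv1d,
    torch.nn.Conv2d: CustomConv2d,
    torch.nn.GroupNorm: CustomGroupNorm,
    torch.nn.Embedding: CustomEmbedding,
}
AUTOCAST_MODULE_TYPE_MAPPING_INVERSE = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}


def apply_custom_layers_to_model(model: torch.nn.Module) -> None:
    # Swap each supported module's class in place. No weights are moved or copied,
    # so this is cheap and reversible.
    for module in model.modules():
        custom_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module))
        if custom_type is not None:
            module.__class__ = custom_type


def remove_custom_layers_from_model(model: torch.nn.Module) -> None:
    # Restore the original classes once the model is fully loaded, avoiding the
    # per-forward cast_to_device() overhead.
    for module in model.modules():
        original_type = AUTOCAST_MODULE_TYPE_MAPPING_INVERSE.get(type(module))
        if original_type is not None:
            module.__class__ = original_type
```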
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)


def cast_to_device(t: T, to_device: torch.device) -> T:
    """Helper function to cast an optional tensor to a target device."""
    if t is None:
        return t

    if t.device.type != to_device.type:
        return t.to(to_device)
    return t
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
import bitsandbytes as bnb
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt


class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        matmul_state = bnb.MatmulLtState()
        matmul_state.threshold = self.state.threshold
        matmul_state.has_fp16_weights = self.state.has_fp16_weights
        matmul_state.use_pool = self.state.use_pool
        matmul_state.is_training = self.training
        # The underlying InvokeInt8Params weight must already be quantized.
        assert self.weight.CB is not None
        matmul_state.CB = cast_to_device(self.weight.CB, x.device)
        matmul_state.SCB = cast_to_device(self.weight.SCB, x.device)

        # Weights are cast automatically as Int8Params, but the bias has to be cast manually.
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        # NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but
        # its dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
        # on the wrong device.
        return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
