Skip to content

Commit b9854c7

Browse files
RyanJDickbrandonrising
authored andcommitted
Rename ModelPatcher methods to reflect that they are general model patching methods and are not LoRA-specific.
1 parent 41dcbb4 commit b9854c7

File tree

9 files changed

+24
-24
lines changed

9 files changed

+24
-24
lines changed

invokeai/app/invocations/compel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
8282
# apply all patches while the model is on the target device
8383
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
8484
tokenizer_info as tokenizer,
85-
ModelPatcher.apply_lora_patches(
85+
ModelPatcher.apply_model_patches(
8686
model=text_encoder,
8787
patches=_lora_loader(),
8888
prefix="lora_te_",
@@ -179,7 +179,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
179179
# apply all patches while the model is on the target device
180180
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
181181
tokenizer_info as tokenizer,
182-
ModelPatcher.apply_lora_patches(
182+
ModelPatcher.apply_model_patches(
183183
text_encoder,
184184
patches=_lora_loader(),
185185
prefix=lora_prefix,

invokeai/app/invocations/denoise_latents.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
10031003
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
10041004
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
10051005
# Apply the LoRA after unet has been moved to its target device for faster patching.
1006-
ModelPatcher.apply_lora_patches(
1006+
ModelPatcher.apply_model_patches(
10071007
model=unet,
10081008
patches=_lora_loader(),
10091009
prefix="lora_unet_",

invokeai/app/invocations/flux_denoise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def _run_diffusion(
311311
if config.format in [ModelFormat.Checkpoint]:
312312
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
313313
exit_stack.enter_context(
314-
ModelPatcher.apply_lora_patches(
314+
ModelPatcher.apply_model_patches(
315315
model=transformer,
316316
patches=self._lora_iterator(context),
317317
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -326,7 +326,7 @@ def _run_diffusion(
326326
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
327327
# than directly patching the weights, but is agnostic to the quantization format.
328328
exit_stack.enter_context(
329-
ModelPatcher.apply_lora_sidecar_patches(
329+
ModelPatcher.apply_model_sidecar_patches(
330330
model=transformer,
331331
patches=self._lora_iterator(context),
332332
prefix=FLUX_LORA_TRANSFORMER_PREFIX,

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
111111
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
112112
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
113113
exit_stack.enter_context(
114-
ModelPatcher.apply_lora_patches(
114+
ModelPatcher.apply_model_patches(
115115
model=clip_text_encoder,
116116
patches=self._clip_lora_iterator(context),
117117
prefix=FLUX_LORA_CLIP_PREFIX,

invokeai/app/invocations/sd3_text_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def _clip_encode(
150150
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
151151
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
152152
exit_stack.enter_context(
153-
ModelPatcher.apply_lora_patches(
153+
ModelPatcher.apply_model_patches(
154154
model=clip_text_encoder,
155155
patches=self._clip_lora_iterator(context, clip_model),
156156
prefix=FLUX_LORA_CLIP_PREFIX,

invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
207207
with (
208208
ExitStack() as exit_stack,
209209
unet_info as unet,
210-
ModelPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
210+
ModelPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
211211
):
212212
assert isinstance(unet, UNet2DConditionModel)
213213
latents = latents.to(device=unet.device, dtype=unet.dtype)

invokeai/backend/patches/model_patcher.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class ModelPatcher:
1717
@staticmethod
1818
@torch.no_grad()
1919
@contextmanager
20-
def apply_lora_patches(
20+
def apply_model_patches(
2121
model: torch.nn.Module,
2222
patches: Iterable[Tuple[LoRAModelRaw, float]],
2323
prefix: str,
@@ -37,7 +37,7 @@ def apply_lora_patches(
3737
original_weights = OriginalWeightsStorage(cached_weights)
3838
try:
3939
for patch, patch_weight in patches:
40-
ModelPatcher.apply_lora_patch(
40+
ModelPatcher.apply_model_patch(
4141
model=model,
4242
prefix=prefix,
4343
patch=patch,
@@ -54,7 +54,7 @@ def apply_lora_patches(
5454

5555
@staticmethod
5656
@torch.no_grad()
57-
def apply_lora_patch(
57+
def apply_model_patch(
5858
model: torch.nn.Module,
5959
prefix: str,
6060
patch: LoRAModelRaw,
@@ -89,7 +89,7 @@ def apply_lora_patch(
8989
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
9090
)
9191

92-
ModelPatcher._apply_lora_layer_patch(
92+
ModelPatcher._apply_model_layer_patch(
9393
module_to_patch=module,
9494
module_to_patch_key=module_key,
9595
patch=layer,
@@ -99,7 +99,7 @@ def apply_lora_patch(
9999

100100
@staticmethod
101101
@torch.no_grad()
102-
def _apply_lora_layer_patch(
102+
def _apply_model_layer_patch(
103103
module_to_patch: torch.nn.Module,
104104
module_to_patch_key: str,
105105
patch: BaseLayerPatch,
@@ -146,7 +146,7 @@ def _apply_lora_layer_patch(
146146
@staticmethod
147147
@torch.no_grad()
148148
@contextmanager
149-
def apply_lora_sidecar_patches(
149+
def apply_model_sidecar_patches(
150150
model: torch.nn.Module,
151151
patches: Iterable[Tuple[LoRAModelRaw, float]],
152152
prefix: str,
@@ -169,7 +169,7 @@ def apply_lora_sidecar_patches(
169169
original_modules: dict[str, torch.nn.Module] = {}
170170
try:
171171
for patch, patch_weight in patches:
172-
ModelPatcher._apply_lora_sidecar_patch(
172+
ModelPatcher._apply_model_sidecar_patch(
173173
model=model,
174174
prefix=prefix,
175175
patch=patch,
@@ -187,7 +187,7 @@ def apply_lora_sidecar_patches(
187187
ModelPatcher._set_submodule(parent_module, module_name, orig_module)
188188

189189
@staticmethod
190-
def _apply_lora_sidecar_patch(
190+
def _apply_model_sidecar_patch(
191191
model: torch.nn.Module,
192192
patch: LoRAModelRaw,
193193
patch_weight: float,
@@ -216,7 +216,7 @@ def _apply_lora_sidecar_patch(
216216
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
217217
)
218218

219-
ModelPatcher._apply_lora_layer_wrapper_patch(
219+
ModelPatcher._apply_model_layer_wrapper_patch(
220220
model=model,
221221
module_to_patch=module,
222222
module_to_patch_key=module_key,
@@ -228,7 +228,7 @@ def _apply_lora_sidecar_patch(
228228

229229
@staticmethod
230230
@torch.no_grad()
231-
def _apply_lora_layer_wrapper_patch(
231+
def _apply_model_layer_wrapper_patch(
232232
model: torch.nn.Module,
233233
module_to_patch: torch.nn.Module,
234234
module_to_patch_key: str,

invokeai/backend/stable_diffusion/extensions/lora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
3232
lora_model = self._node_context.models.load(self._model_id).model
3333
assert isinstance(lora_model, LoRAModelRaw)
34-
ModelPatcher.apply_lora_patch(
34+
ModelPatcher.apply_model_patch(
3535
model=unet,
3636
prefix="lora_unet_",
3737
patch=lora_model,

tests/backend/patches/test_lora_patcher.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
5353
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
5454
expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)
5555

56-
with ModelPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
56+
with ModelPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
5757
# After patching, all LoRA layer weights should have been moved back to the cpu.
5858
for lora, _ in lora_models:
5959
assert lora.layers["linear_layer_1"].up.device.type == "cpu"
@@ -93,7 +93,7 @@ def test_apply_lora_patches_change_device():
9393

9494
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
9595

96-
with ModelPatcher.apply_lora_patches(model=model, patches=[(lora, 0.5)], prefix=""):
96+
with ModelPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""):
9797
# After patching, all LoRA layer weights should have been moved back to the cpu.
9898
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
9999
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
@@ -146,7 +146,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
146146
output_before_patch = model(input)
147147

148148
# Patch the model and run inference during the patch.
149-
with ModelPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
149+
with ModelPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
150150
output_during_patch = model(input)
151151

152152
# Run inference after unpatching.
@@ -186,10 +186,10 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
186186

187187
input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)
188188

189-
with ModelPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
189+
with ModelPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
190190
output_lora_patches = model(input)
191191

192-
with ModelPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
192+
with ModelPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
193193
output_lora_sidecar_patches = model(input)
194194

195195
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical

0 commit comments

Comments
 (0)