Skip to content

Commit 594511c

Browse files
authored
Add FLUX Control LoRA weight param (#7452)
## Summary

Add the ability to control the weight of a FLUX Control LoRA.

## Example

Original image:

<div style="display: flex; gap: 10px;">
<img src="https://github.com/user-attachments/assets/4a2d9f4a-b58b-4df6-af90-67b018763a38" alt="Image 1" width="300"/>
</div>

Prompt: `a scarecrow playing tennis`

Weights: 0.4, 0.6, 0.8, 1.0

<div style="display: flex; gap: 10px;">
<img src="https://github.com/user-attachments/assets/62b83fd6-46ce-460a-8d51-9c2cda9b05c9" alt="Image 1" width="300"/>
<img src="https://github.com/user-attachments/assets/75442207-1538-46bc-9d6b-08ac5c235c93" alt="Image 2" width="300"/>
</div>
<div style="display: flex; gap: 10px;">
<img src="https://github.com/user-attachments/assets/4a9dc9ea-9757-4965-837e-197fc9243007" alt="Image 1" width="300"/>
<img src="https://github.com/user-attachments/assets/846f6918-ca82-4482-8c19-19172752fa8c" alt="Image 2" width="300"/>
</div>

## QA Instructions

- [x] weight control changes strength of control image
- [x] Test that results match across both quantized and non-quantized.

## Merge Plan

**_Do not merge this PR yet._**

1. Merge #7450
2. Merge #7446
3. Change target branch to main
4. Merge this branch.

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2 parents 4d5f74c + d764aa4 commit 594511c

File tree

8 files changed

+60
-16
lines changed

8 files changed

+60
-16
lines changed

invokeai/app/invocations/flux_control_lora_loader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class FluxControlLoRALoaderOutput(BaseInvocationOutput):
2424
title="Flux Control LoRA",
2525
tags=["lora", "model", "flux"],
2626
category="model",
27-
version="1.0.0",
27+
version="1.1.0",
2828
classification=Classification.Prototype,
2929
)
3030
class FluxControlLoRALoaderInvocation(BaseInvocation):
@@ -34,6 +34,7 @@ class FluxControlLoRALoaderInvocation(BaseInvocation):
3434
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
3535
)
3636
image: ImageField = InputField(description="The image to encode.")
37+
weight: float = InputField(description="The weight of the LoRA.", default=1.0)
3738

3839
def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
3940
if not context.models.exists(self.lora.key):
@@ -43,6 +44,6 @@ def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
4344
control_lora=ControlLoRAField(
4445
lora=self.lora,
4546
img=self.image,
46-
weight=1,
47+
weight=self.weight,
4748
)
4849
)

invokeai/backend/patches/layers/set_parameter_layer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ def __init__(self, param_name: str, weight: torch.Tensor):
1515
self.param_name = param_name
1616

1717
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
18+
# Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
19+
# Control LoRA implementation.
1820
diff = self.weight - orig_module.get_parameter(self.param_name)
19-
return {self.param_name: diff * weight}
21+
return {self.param_name: diff}
2022

2123
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
2224
self.weight = self.weight.to(device=device, dtype=dtype)

invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ export const ControlLayerControlAdapter = memo(() => {
162162
/>
163163
<input {...uploadApi.getUploadInputProps()} />
164164
</Flex>
165-
{controlAdapter.type !== 'control_lora' && <Weight weight={controlAdapter.weight} onChange={onChangeWeight} />}
165+
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
166166
{controlAdapter.type !== 'control_lora' && (
167167
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
168168
)}

invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ import {
7474
getReferenceImageState,
7575
getRegionalGuidanceState,
7676
imageDTOToImageWithDims,
77+
initialControlLoRA,
7778
initialControlNet,
7879
initialIPAdapter,
7980
initialT2IAdapter,
@@ -462,38 +463,64 @@ export const canvasSlice = createSlice({
462463
}
463464
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);
464465

466+
// When converting between control layer types, we may need to add or remove properties. For example, ControlNet
467+
// has a control mode, while T2I Adapter does not - otherwise they are the same.
468+
465469
switch (layer.controlAdapter.model.type) {
470+
// Converting to T2I adapter from...
466471
case 't2i_adapter': {
467472
if (layer.controlAdapter.type === 'controlnet') {
473+
// T2I Adapters have all the ControlNet properties, minus control mode - strip it
468474
const { controlMode: _, ...rest } = layer.controlAdapter;
469-
const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' };
475+
const t2iAdapterConfig: T2IAdapterConfig = { ...initialT2IAdapter, ...rest, type: 't2i_adapter' };
470476
layer.controlAdapter = t2iAdapterConfig;
471477
} else if (layer.controlAdapter.type === 'control_lora') {
472-
const t2iAdapterConfig: T2IAdapterConfig = { ...layer.controlAdapter, ...initialT2IAdapter };
478+
// Control LoRAs have only model and weight
479+
const t2iAdapterConfig: T2IAdapterConfig = {
480+
...initialT2IAdapter,
481+
...layer.controlAdapter,
482+
type: 't2i_adapter',
483+
};
473484
layer.controlAdapter = t2iAdapterConfig;
474485
}
475486
break;
476487
}
477488

489+
// Converting to ControlNet from...
478490
case 'controlnet': {
479491
if (layer.controlAdapter.type === 't2i_adapter') {
492+
// ControlNets have all the T2I Adapter properties, plus control mode
480493
const controlNetConfig: ControlNetConfig = {
494+
...initialControlNet,
481495
...layer.controlAdapter,
482496
type: 'controlnet',
483-
controlMode: initialControlNet.controlMode,
484497
};
485498
layer.controlAdapter = controlNetConfig;
486499
} else if (layer.controlAdapter.type === 'control_lora') {
487-
const controlNetConfig: ControlNetConfig = { ...layer.controlAdapter, ...initialControlNet };
500+
// ControlNets have all the Control LoRA properties, plus control mode and begin/end step pct
501+
const controlNetConfig: ControlNetConfig = {
502+
...initialControlNet,
503+
...layer.controlAdapter,
504+
type: 'controlnet',
505+
};
488506
layer.controlAdapter = controlNetConfig;
489507
}
490508
break;
491509
}
492510

511+
// Converting to ControlLoRA from...
493512
case 'control_lora': {
494-
const controlLoraConfig: ControlLoRAConfig = { ...layer.controlAdapter, type: 'control_lora' };
495-
layer.controlAdapter = controlLoraConfig;
496-
513+
if (layer.controlAdapter.type === 'controlnet') {
514+
// We only need the model and weight for Control LoRA
515+
const { model, weight } = layer.controlAdapter;
516+
const controlNetConfig: ControlLoRAConfig = { ...initialControlLoRA, model, weight };
517+
layer.controlAdapter = controlNetConfig;
518+
} else if (layer.controlAdapter.type === 't2i_adapter') {
519+
// We only need the model and weight for Control LoRA
520+
const { model, weight } = layer.controlAdapter;
521+
const t2iAdapterConfig: ControlLoRAConfig = { ...initialControlLoRA, model, weight };
522+
layer.controlAdapter = t2iAdapterConfig;
523+
}
497524
break;
498525
}
499526

@@ -518,7 +545,7 @@ export const canvasSlice = createSlice({
518545
) => {
519546
const { entityIdentifier, weight } = action.payload;
520547
const layer = selectEntity(state, entityIdentifier);
521-
if (!layer || !layer.controlAdapter || layer.controlAdapter.type === 'control_lora') {
548+
if (!layer || !layer.controlAdapter) {
522549
return;
523550
}
524551
layer.controlAdapter.weight = weight;

invokeai/frontend/web/src/features/controlLayers/store/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
298298

299299
const zControlLoRAConfig = z.object({
300300
type: z.literal('control_lora'),
301+
weight: z.number().gte(-1).lte(2),
301302
model: zServerValidatedModelIdentifierField.nullable(),
302303
});
303304
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;

invokeai/frontend/web/src/features/controlLayers/store/util.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import type {
77
CanvasRasterLayerState,
88
CanvasReferenceImageState,
99
CanvasRegionalGuidanceState,
10+
ControlLoRAConfig,
1011
ControlNetConfig,
1112
ImageWithDims,
1213
IPAdapterConfig,
@@ -82,6 +83,11 @@ export const initialControlNet: ControlNetConfig = {
8283
beginEndStepPct: [0, 0.75],
8384
controlMode: 'balanced',
8485
};
86+
export const initialControlLoRA: ControlLoRAConfig = {
87+
type: 'control_lora',
88+
model: null,
89+
weight: 0.75,
90+
};
8591

8692
export const getReferenceImageState = (
8793
id: string,

invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ const addControlLoRAToGraph = (
207207
) => {
208208
const { id, controlAdapter } = layer;
209209
assert(controlAdapter.type === 'control_lora');
210-
const { model } = controlAdapter;
210+
const { model, weight } = controlAdapter;
211211
assert(model !== null);
212212
const { image_name } = imageDTO;
213213

@@ -216,6 +216,7 @@ const addControlLoRAToGraph = (
216216
type: 'flux_control_lora_loader',
217217
lora: model,
218218
image: { image_name },
219+
weight: weight,
219220
});
220221

221222
g.addEdge(controlLoRA, 'control_lora', denoise, 'control_lora');

invokeai/frontend/web/src/services/api/schema.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6708,6 +6708,12 @@ export type components = {
67086708
* @default null
67096709
*/
67106710
image?: components["schemas"]["ImageField"];
6711+
/**
6712+
* Weight
6713+
* @description The weight of the LoRA.
6714+
* @default 1
6715+
*/
6716+
weight?: number;
67116717
/**
67126718
* type
67136719
* @default flux_control_lora_loader
@@ -6722,11 +6728,11 @@ export type components = {
67226728
*/
67236729
FluxControlLoRALoaderOutput: {
67246730
/**
6725-
* Flux Control Lora
6731+
* Flux Control LoRA
67266732
* @description Control LoRAs to apply on model loading
67276733
* @default null
67286734
*/
6729-
control_lora: components["schemas"]["ControlLoRAField"] | null;
6735+
control_lora: components["schemas"]["ControlLoRAField"];
67306736
/**
67316737
* type
67326738
* @default flux_control_lora_loader_output
@@ -6926,7 +6932,7 @@ export type components = {
69266932
*/
69276933
transformer?: components["schemas"]["TransformerField"];
69286934
/**
6929-
* Control Lora
6935+
* Control LoRA
69306936
* @description Control LoRA model to load
69316937
* @default null
69326938
*/

0 commit comments

Comments (0)