Skip to content

Commit 8f02ce5

Browse files
perf(ui): cache image data & transparency mode during generation mode calculation
Perf boost and reduces the number of images we create on the backend.
1 parent f4b7c63 commit 8f02ce5

File tree

5 files changed

+120
-28
lines changed

5 files changed

+120
-28
lines changed

invokeai/frontend/web/src/features/controlLayers/konva/CanvasCacheModule.ts

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,32 @@
11
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
22
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
3+
import type { Transparency } from 'features/controlLayers/konva/util';
34
import { getPrefixedId } from 'features/controlLayers/konva/util';
45
import type { GenerationMode } from 'features/controlLayers/store/types';
56
import { LRUCache } from 'lru-cache';
67
import type { Logger } from 'roarr';
78

9+
type GetCacheEntryWithFallbackArg<T extends NonNullable<unknown>> = {
10+
cache: LRUCache<string, T>;
11+
key: string;
12+
getValue: () => Promise<T>;
13+
onHit?: (value: T) => void;
14+
onMiss?: () => void;
15+
};
16+
817
type CanvasCacheModuleConfig = {
918
/**
1019
* The maximum size of the image name cache.
1120
*/
1221
imageNameCacheSize: number;
22+
/**
23+
* The maximum size of the image data cache.
24+
*/
25+
imageDataCacheSize: number;
26+
/**
27+
* The maximum size of the transparency calculation cache.
28+
*/
29+
transparencyCalculationCacheSize: number;
1330
/**
1431
* The maximum size of the canvas element cache.
1532
*/
@@ -21,7 +38,9 @@ type CanvasCacheModuleConfig = {
2138
};
2239

2340
const DEFAULT_CONFIG: CanvasCacheModuleConfig = {
24-
imageNameCacheSize: 100,
41+
imageNameCacheSize: 1000,
42+
imageDataCacheSize: 32,
43+
transparencyCalculationCacheSize: 1000,
2544
canvasElementCacheSize: 32,
2645
generationModeCacheSize: 100,
2746
};
@@ -41,26 +60,38 @@ export class CanvasCacheModule extends CanvasModuleBase {
4160
config: CanvasCacheModuleConfig = DEFAULT_CONFIG;
4261

4362
/**
44-
* A cache for storing image names. Used as a cache for results of layer/canvas/entity exports. For example, when we
45-
* rasterize a layer and upload it to the server, we store the image name in this cache.
63+
* A cache for storing image names.
4664
*
47-
* The cache key is a hash of the exported entity's state and the export rect.
65+
* For example, the key might be a hash of a composite of entities with the uploaded image name as the value.
4866
*/
4967
imageNameCache = new LRUCache<string, string>({ max: this.config.imageNameCacheSize });
5068

5169
/**
52-
* A cache for storing canvas elements. Similar to the image name cache, but for canvas elements. The primary use is
53-
* for caching composite layers. For example, the canvas compositor module uses this to store the canvas elements for
54-
* individual raster layers when creating a composite of the layers.
70+
* A cache for storing canvas elements.
5571
*
56-
* The cache key is a hash of the exported entity's state and the export rect.
72+
* For example, the key might be a hash of a composite of entities with the canvas element as the value.
5773
*/
5874
canvasElementCache = new LRUCache<string, HTMLCanvasElement>({ max: this.config.canvasElementCacheSize });
75+
5976
/**
60-
* A cache for the generation mode calculation, which is fairly expensive.
77+
* A cache for image data objects.
6178
*
62-
* The cache key is a hash of all the objects that contribute to the generation mode calculation (e.g. the composite
63-
* raster layer, the composite inpaint mask, and bounding box), and the value is the generation mode.
79+
* For example, the key might be a hash of a composite of entities with the image data as the value.
80+
*/
81+
imageDataCache = new LRUCache<string, ImageData>({ max: this.config.imageDataCacheSize });
82+
83+
/**
84+
* A cache for transparency calculation results.
85+
*
86+
* For example, the key might be a hash of a composite of entities with the transparency as the value.
87+
*/
88+
transparencyCalculationCache = new LRUCache<string, Transparency>({ max: this.config.imageDataCacheSize });
89+
90+
/**
91+
* A cache for generation mode calculation results.
92+
*
93+
* For example, the key might be a hash of a composite of raster and inpaint mask entities with the generation mode
94+
* as the value.
6495
*/
6596
generationModeCache = new LRUCache<string, GenerationMode>({ max: this.config.generationModeCacheSize });
6697

@@ -75,6 +106,33 @@ export class CanvasCacheModule extends CanvasModuleBase {
75106
this.log.debug('Creating cache module');
76107
}
77108

109+
/**
110+
* A helper function for getting a cache entry with a fallback.
111+
* @param param0.cache The LRUCache to get the entry from.
112+
* @param param0.key The key to use to retrieve the entry.
113+
* @param param0.getValue An async function to generate the value if the entry is not in the cache.
114+
* @param param0.onHit An optional function to call when the entry is in the cache.
115+
* @param param0.onMiss An optional function to call when the entry is not in the cache.
116+
* @returns
117+
*/
118+
static getWithFallback = async <T extends NonNullable<unknown>>({
119+
cache,
120+
getValue,
121+
key,
122+
onHit,
123+
onMiss,
124+
}: GetCacheEntryWithFallbackArg<T>): Promise<T> => {
125+
let value = cache.get(key);
126+
if (value === undefined) {
127+
onMiss?.();
128+
value = await getValue();
129+
cache.set(key, value);
130+
} else {
131+
onHit?.(value);
132+
}
133+
return value;
134+
};
135+
78136
/**
79137
* Clears all caches.
80138
*/

invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import type { SerializableObject } from 'common/types';
22
import { withResultAsync } from 'common/util/result';
3+
import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule';
34
import type { CanvasEntityAdapter, CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
45
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
56
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
7+
import type { Transparency } from 'features/controlLayers/konva/util';
68
import {
79
canvasToBlob,
810
canvasToImageData,
@@ -415,6 +417,38 @@ export class CanvasCompositorModule extends CanvasModuleBase {
415417
return this.mergeByEntityIdentifiers(entityIdentifiers, false);
416418
};
417419

420+
/**
421+
* Calculates the transparency of the composite of the give adapters.
422+
* @param adapters The adapters to composite
423+
* @param rect The region to include in the composite
424+
* @param hash The hash to use for caching the result
425+
* @returns A promise that resolves to the transparency of the composite
426+
*/
427+
getTransparency = (adapters: CanvasEntityAdapter[], rect: Rect, hash: string): Promise<Transparency> => {
428+
const entityIdentifiers = adapters.map((adapter) => adapter.entityIdentifier);
429+
const logCtx = { entityIdentifiers, rect };
430+
return CanvasCacheModule.getWithFallback({
431+
cache: this.manager.cache.transparencyCalculationCache,
432+
key: hash,
433+
getValue: async () => {
434+
this.$isProcessing.set(true);
435+
const compositeInpaintMaskCanvas = this.getCompositeCanvas(adapters, rect);
436+
437+
const compositeInpaintMaskImageData = await CanvasCacheModule.getWithFallback({
438+
cache: this.manager.cache.imageDataCache,
439+
key: hash,
440+
getValue: () => Promise.resolve(canvasToImageData(compositeInpaintMaskCanvas)),
441+
onHit: () => this.log.trace(logCtx, 'Using cached image data'),
442+
onMiss: () => this.log.trace(logCtx, 'Calculating image data'),
443+
});
444+
445+
return getImageDataTransparency(compositeInpaintMaskImageData);
446+
},
447+
onHit: () => this.log.trace(logCtx, 'Using cached transparency'),
448+
onMiss: () => this.log.trace(logCtx, 'Calculating transparency'),
449+
});
450+
};
451+
418452
/**
419453
* Calculates the generation mode for the current canvas state. This is determined by the transparency of the
420454
* composite raster layer and composite inpaint mask:
@@ -433,11 +467,11 @@ export class CanvasCompositorModule extends CanvasModuleBase {
433467
*
434468
* @returns The generation mode
435469
*/
436-
getGenerationMode(): GenerationMode {
470+
getGenerationMode = async (): Promise<GenerationMode> => {
437471
const { rect } = this.manager.stateApi.getBbox();
438472

439-
const rasterAdapters = this.manager.compositor.getVisibleAdaptersOfType('raster_layer');
440-
const compositeRasterLayerHash = this.getCompositeHash(rasterAdapters, { rect });
473+
const rasterLayerAdapters = this.manager.compositor.getVisibleAdaptersOfType('raster_layer');
474+
const compositeRasterLayerHash = this.getCompositeHash(rasterLayerAdapters, { rect });
441475

442476
const inpaintMaskAdapters = this.manager.compositor.getVisibleAdaptersOfType('inpaint_mask');
443477
const compositeInpaintMaskHash = this.getCompositeHash(inpaintMaskAdapters, { rect });
@@ -452,17 +486,17 @@ export class CanvasCompositorModule extends CanvasModuleBase {
452486

453487
this.log.debug({ rect }, 'Calculating generation mode');
454488

455-
const compositeInpaintMaskCanvas = this.getCompositeCanvas(inpaintMaskAdapters, rect);
456-
this.$isProcessing.set(true);
457-
const compositeInpaintMaskImageData = canvasToImageData(compositeInpaintMaskCanvas);
458-
const compositeInpaintMaskTransparency = getImageDataTransparency(compositeInpaintMaskImageData);
459-
this.$isProcessing.set(false);
489+
const compositeRasterLayerTransparency = await this.getTransparency(
490+
rasterLayerAdapters,
491+
rect,
492+
compositeRasterLayerHash
493+
);
460494

461-
const compositeRasterLayerCanvas = this.getCompositeCanvas(rasterAdapters, rect);
462-
this.$isProcessing.set(true);
463-
const compositeRasterLayerImageData = canvasToImageData(compositeRasterLayerCanvas);
464-
const compositeRasterLayerTransparency = getImageDataTransparency(compositeRasterLayerImageData);
465-
this.$isProcessing.set(false);
495+
const compositeInpaintMaskTransparency = await this.getTransparency(
496+
inpaintMaskAdapters,
497+
rect,
498+
compositeInpaintMaskHash
499+
);
466500

467501
let generationMode: GenerationMode;
468502
if (compositeRasterLayerTransparency === 'FULLY_TRANSPARENT') {
@@ -482,7 +516,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
482516

483517
this.manager.cache.generationModeCache.set(hash, generationMode);
484518
return generationMode;
485-
}
519+
};
486520

487521
repr = () => {
488522
return {

invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ export const buildFLUXGraph = async (
3434
state: RootState,
3535
manager: CanvasManager
3636
): Promise<{ g: Graph; noise: Invocation<'noise' | 'flux_denoise'>; posCond: Invocation<'flux_text_encoder'> }> => {
37-
const generationMode = manager.compositor.getGenerationMode();
37+
const generationMode = await manager.compositor.getGenerationMode();
3838
log.debug({ generationMode }, 'Building FLUX graph');
3939

4040
const params = selectParamsSlice(state);

invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ export const buildSD1Graph = async (
3737
state: RootState,
3838
manager: CanvasManager
3939
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel'> }> => {
40-
const generationMode = manager.compositor.getGenerationMode();
40+
const generationMode = await manager.compositor.getGenerationMode();
4141
log.debug({ generationMode }, 'Building SD1/SD2 graph');
4242

4343
const params = selectParamsSlice(state);

invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ export const buildSDXLGraph = async (
3737
state: RootState,
3838
manager: CanvasManager
3939
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'sdxl_compel_prompt'> }> => {
40-
const generationMode = manager.compositor.getGenerationMode();
40+
const generationMode = await manager.compositor.getGenerationMode();
4141
log.debug({ generationMode }, 'Building SDXL graph');
4242

4343
const params = selectParamsSlice(state);

0 commit comments

Comments
 (0)