Commit 3613ee3

Add an endpoint for emptying the model cache. Also add a threading lock to the ModelCache to make it thread-safe.
1 parent debcbd6 · commit 3613ee3

File tree (3 files changed: +79 −1 lines changed)

- invokeai/app/api/routers/model_manager.py
- invokeai/backend/model_manager/load/model_cache/model_cache.py
- invokeai/frontend/web/src/services/api/schema.ts


invokeai/app/api/routers/model_manager.py

Lines changed: 12 additions & 0 deletions
@@ -858,6 +858,18 @@ async def get_stats() -> Optional[CacheStats]:
     return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats


+@model_manager_router.post(
+    "/empty_model_cache",
+    operation_id="empty_model_cache",
+    status_code=200,
+)
+async def empty_model_cache() -> None:
+    """Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped."""
+    # Request 1000GB of room in order to force the cache to drop all models.
+    ApiDependencies.invoker.services.logger.info("Emptying model cache.")
+    ApiDependencies.invoker.services.model_manager.load.ram_cache.make_room(1000 * 2**30)
+
+
 class HFTokenStatus(str, Enum):
     VALID = "valid"
     INVALID = "invalid"
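The handler requests 1000 * 2**30 bytes (1000 GiB) of room, far more than any realistic cache size, so make_room ends up evicting every model that is not currently locked. A minimal sketch of calling the new route, assuming a locally running InvokeAI instance and the requests package (the host and port are assumptions, not part of this commit; the path matches the generated schema further down):

import requests

# Hypothetical client call; the host/port are assumed defaults, adjust for your setup.
BASE_URL = "http://127.0.0.1:9090"

response = requests.post(f"{BASE_URL}/api/v2/models/empty_model_cache", timeout=60)
response.raise_for_status()  # the endpoint returns HTTP 200 with an empty body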

invokeai/backend/model_manager/load/model_cache/model_cache.py

Lines changed: 27 additions & 1 deletion
@@ -1,8 +1,10 @@
 import gc
 import logging
+import threading
 import time
+from functools import wraps
 from logging import Logger
-from typing import Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import psutil
 import torch
@@ -41,6 +43,17 @@ def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] =
     return model_key


+def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
+    """A decorator that applies the class's self._lock to the method."""
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        with self._lock:  # Automatically acquire and release the lock
+            return method(self, *args, **kwargs)
+
+    return wrapper
+
+
 class ModelCache:
     """A cache for managing models in memory.

@@ -125,16 +138,25 @@ def __init__(

         self._ram_cache_size_bytes = self._calc_ram_available_to_model_cache()

+        # A lock applied to all public method calls to make the ModelCache thread-safe.
+        # At the time of writing, the ModelCache should only be accessed from two threads:
+        # - The graph execution thread
+        # - Requests to empty the cache from a separate thread
+        self._lock = threading.RLock()
+
     @property
+    @synchronized
     def stats(self) -> Optional[CacheStats]:
         """Return collected CacheStats object."""
         return self._stats

     @stats.setter
+    @synchronized
     def stats(self, stats: CacheStats) -> None:
         """Set the CacheStats object for collecting cache statistics."""
         self._stats = stats

+    @synchronized
     def put(self, key: str, model: AnyModel) -> None:
         """Add a model to the cache."""
         if key in self._cached_models:
@@ -173,6 +195,7 @@ def put(self, key: str, model: AnyModel) -> None:
             f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
         )

+    @synchronized
     def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         """Retrieve a model from the cache.

@@ -208,6 +231,7 @@ def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
         return cache_entry

+    @synchronized
     def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
         """Lock a model for use and move it into VRAM."""
         if cache_entry.key not in self._cached_models:
@@ -243,6 +267,7 @@ def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> No

         self._log_cache_state()

+    @synchronized
     def unlock(self, cache_entry: CacheRecord) -> None:
         """Unlock a model."""
         if cache_entry.key not in self._cached_models:
@@ -588,6 +613,7 @@ def _log_cache_state(self, title: str = "Model cache state:", include_entry_deta

         self._logger.debug(log)

+    @synchronized
     def make_room(self, bytes_needed: int) -> None:
         """Make enough room in the cache to accommodate a new model of indicated size.
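The synchronized decorator above simply runs each decorated method under self._lock. Because the lock is a threading.RLock (re-entrant), a synchronized public method can call another synchronized method on the same instance without deadlocking, which a plain threading.Lock would not allow. A minimal standalone sketch of the same pattern, using an illustrative Counter class that is not part of this commit:

import threading
from functools import wraps
from typing import Any, Callable


def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
    """Acquire self._lock for the duration of the wrapped method (same pattern as in model_cache.py)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        with self._lock:
            return method(self, *args, **kwargs)

    return wrapper


class Counter:
    """Illustrative only: shows why the re-entrant lock matters."""

    def __init__(self) -> None:
        self._lock = threading.RLock()  # RLock lets the owning thread re-acquire it
        self._value = 0

    @synchronized
    def increment(self) -> None:
        self._value += 1

    @synchronized
    def add_two(self) -> None:
        # Calls another synchronized method while already holding the lock.
        # With a non-re-entrant threading.Lock this would deadlock.
        self.increment()
        self.increment()


if __name__ == "__main__":
    counter = Counter()
    counter.add_two()
    print(counter._value)  # 2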

invokeai/frontend/web/src/services/api/schema.ts

Lines changed: 40 additions & 0 deletions
@@ -320,6 +320,26 @@ export type paths = {
         patch?: never;
         trace?: never;
     };
+    "/api/v2/models/empty_model_cache": {
+        parameters: {
+            query?: never;
+            header?: never;
+            path?: never;
+            cookie?: never;
+        };
+        get?: never;
+        put?: never;
+        /**
+         * Empty Model Cache
+         * @description Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped.
+         */
+        post: operations["empty_model_cache"];
+        delete?: never;
+        options?: never;
+        head?: never;
+        patch?: never;
+        trace?: never;
+    };
     "/api/v2/models/hf_login": {
         parameters: {
             query?: never;
@@ -20327,6 +20347,26 @@ export interface operations {
             };
         };
     };
+    empty_model_cache: {
+        parameters: {
+            query?: never;
+            header?: never;
+            path?: never;
+            cookie?: never;
+        };
+        requestBody?: never;
+        responses: {
+            /** @description Successful Response */
+            200: {
+                headers: {
+                    [name: string]: unknown;
+                };
+                content: {
+                    "application/json": unknown;
+                };
+            };
+        };
+    };
     get_hf_login_status: {
         parameters: {
             query?: never;
