@@ -1,8 +1,10 @@
 import gc
 import logging
+import threading
 import time
+from functools import wraps
 from logging import Logger
-from typing import Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import psutil
 import torch
@@ -41,6 +43,17 @@ def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] =
     return model_key
 
 
+def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
+    """A decorator that applies the class's self._lock to the method."""
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        with self._lock:  # Automatically acquire and release the lock
+            return method(self, *args, **kwargs)
+
+    return wrapper
+
+
 class ModelCache:
     """A cache for managing models in memory.
 
@@ -125,16 +138,25 @@ def __init__(
 
         self._ram_cache_size_bytes = self._calc_ram_available_to_model_cache()
 
+        # A lock applied to all public method calls to make the ModelCache thread-safe.
+        # At the time of writing, the ModelCache should only be accessed from two threads:
+        # - The graph execution thread
+        # - Requests to empty the cache from a separate thread
+        self._lock = threading.RLock()
+
     @property
+    @synchronized
     def stats(self) -> Optional[CacheStats]:
         """Return collected CacheStats object."""
         return self._stats
 
     @stats.setter
+    @synchronized
     def stats(self, stats: CacheStats) -> None:
         """Set the CacheStats object for collecting cache statistics."""
         self._stats = stats
 
+    @synchronized
     def put(self, key: str, model: AnyModel) -> None:
         """Add a model to the cache."""
         if key in self._cached_models:
@@ -173,6 +195,7 @@ def put(self, key: str, model: AnyModel) -> None:
             f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
         )
 
+    @synchronized
     def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         """Retrieve a model from the cache.
 
@@ -208,6 +231,7 @@ def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
         return cache_entry
 
+    @synchronized
     def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
         """Lock a model for use and move it into VRAM."""
         if cache_entry.key not in self._cached_models:
@@ -243,6 +267,7 @@ def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> No
 
         self._log_cache_state()
 
+    @synchronized
     def unlock(self, cache_entry: CacheRecord) -> None:
         """Unlock a model."""
         if cache_entry.key not in self._cached_models:
@@ -588,6 +613,7 @@ def _log_cache_state(self, title: str = "Model cache state:", include_entry_deta
 
         self._logger.debug(log)
 
+    @synchronized
     def make_room(self, bytes_needed: int) -> None:
         """Make enough room in the cache to accommodate a new model of indicated size.
 
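The `synchronized` decorator introduced in this diff is a small, reusable pattern: it assumes only that the instance carries a `self._lock` attribute. The choice of `threading.RLock` over `threading.Lock` matters, because a reentrant lock lets one `@synchronized` method call another on the same instance without deadlocking, and the pattern composes with `@property` exactly as the diff stacks it (the property decorators are listed first, i.e. applied last, so they wrap the already-locked functions). Below is a minimal, self-contained sketch of the pattern; the `Counter` class is a hypothetical stand-in, not part of `ModelCache`.

import threading
from functools import wraps
from typing import Any, Callable


def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
    """Acquire self._lock around every call to the decorated method."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        with self._lock:
            return method(self, *args, **kwargs)

    return wrapper


class Counter:
    """Hypothetical stand-in for ModelCache, used only to illustrate the pattern."""

    def __init__(self) -> None:
        # RLock, not Lock: the same thread may re-acquire it, so one
        # @synchronized method can safely call another on this instance.
        self._lock = threading.RLock()
        self._value = 0

    @property  # Applied last, so it wraps the already-synchronized getter.
    @synchronized
    def value(self) -> int:
        return self._value

    @synchronized
    def increment(self) -> None:
        self._value += 1

    @synchronized
    def drain(self) -> int:
        # The nested property access re-enters the RLock already held by
        # this thread; with a plain threading.Lock it would deadlock.
        current = self.value
        self._value = 0
        return current

With a non-reentrant `threading.Lock`, `drain` would block forever on its own lock at `self.value`. The reentrant lock is presumably what lets the `@synchronized` `ModelCache` methods above call into one another without each needing to know which locks are already held.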