Commit d3916db
Partial Loading PR1: Tidy ModelCache (#7492)
## Summary

This PR tidies up the model cache code in preparation for further refactoring to support partial loading of models onto the GPU. **These code changes should not change the functional behavior in any way.**

Changes:

- Remove the `ModelCacheBase` class. `ModelCache` is the only implementation, so there is no benefit to the separate abstract class.
- Split `CacheRecord` and `CacheStats` out into their own files.
- Remove the `ModelLocker` class. This extra layer of indirection was not providing any benefit. Locking is now done directly with the `ModelCache` (a before/after sketch follows the checklist below).
- Tidy up relative imports that were contributing to circular import issues.
- Pull the 'submodel' concern out of the `ModelCache`. The `ModelCache` should not need to be aware of the model manager submodel system.
- Delete unused properties from the `ModelCache` (e.g. `.lazy_offloading`, `.storage_device`, etc.).

## QA Instructions

I ran smoke tests with a variety of SD1, SDXL and FLUX models. No change to behavior is expected.

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.-->

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
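To make the locking change concrete, here is a minimal before/after sketch in Python. It relies only on the `ModelCache.get()`/`lock()`/`unlock()` calls and `CacheRecord` fields visible in the diffs below; `ram_cache`, `cache_key`, and `run_inference` are illustrative placeholders, not code from this PR.

```python
from typing import Any, Callable


def use_model_old(ram_cache: Any, cache_key: str, run_inference: Callable[[Any], None]) -> None:
    """Old flow (pre-PR, illustrative): an intermediate ModelLocker mediated device placement."""
    locker = ram_cache.get(key=cache_key)  # returned a ModelLockerBase
    model = locker.lock()                  # moved the model onto the execution device
    try:
        run_inference(model)
    finally:
        locker.unlock()


def use_model_new(ram_cache: Any, cache_key: str, run_inference: Callable[[Any], None]) -> None:
    """New flow (this PR): callers lock/unlock directly on the ModelCache, keyed by cache key."""
    cache_record = ram_cache.get(key=cache_key)  # returns a CacheRecord
    ram_cache.lock(cache_record.key)
    try:
        run_inference(cache_record.model)
    finally:
        ram_cache.unlock(cache_record.key)
```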
2 parents 712674b + 55b13c1 commit d3916db

18 files changed: +197 −438 lines

docs/contributing/MODEL_MANAGER.md

Lines changed: 0 additions & 1 deletion
@@ -1364,7 +1364,6 @@ the in-memory loaded model:
 |----------------|-----------------|------------------|
 | `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
 | `model` | AnyModel | The instantiated model (details below) |
-| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
 
 ### get_model_by_key(key, [submodel]) -> LoadedModel
 

invokeai/app/api/routers/model_manager.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@
     ModelFormat,
     ModelType,
 )
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
+from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
 from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
 from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
 from invokeai.backend.model_manager.search import ModelSearch

invokeai/app/services/invocation_stats/invocation_stats_default.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
     NodeExecutionStatsSummary,
 )
 from invokeai.app.services.invoker import Invoker
-from invokeai.backend.model_manager.load.model_cache import CacheStats
+from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
 
 # Size of 1GB in bytes.
 GB = 2**30

invokeai/app/services/model_load/model_load_base.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 
 
 class ModelLoadServiceBase(ABC):
@@ -24,7 +24,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
 
     @property
     @abstractmethod
-    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+    def ram_cache(self) -> ModelCache:
         """Return the RAM cache used by this loader."""
 
     @abstractmethod

invokeai/app/services/model_load/model_load_default.py

Lines changed: 6 additions & 7 deletions
@@ -18,7 +18,7 @@
     ModelLoaderRegistry,
     ModelLoaderRegistryBase,
 )
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
@@ -30,7 +30,7 @@ class ModelLoadService(ModelLoadServiceBase):
     def __init__(
         self,
         app_config: InvokeAIAppConfig,
-        ram_cache: ModelCacheBase[AnyModel],
+        ram_cache: ModelCache,
         registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
     ):
         """Initialize the model load service."""
@@ -45,7 +45,7 @@ def start(self, invoker: Invoker) -> None:
         self._invoker = invoker
 
     @property
-    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+    def ram_cache(self) -> ModelCache:
         """Return the RAM cache used by this loader."""
         return self._ram_cache
 
@@ -78,9 +78,8 @@ def load_model_from_path(
         self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
     ) -> LoadedModelWithoutConfig:
         cache_key = str(model_path)
-        ram_cache = self.ram_cache
         try:
-            return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
+            return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
         except IndexError:
             pass
 
@@ -109,5 +108,5 @@ def diffusers_load_directory(directory: Path) -> AnyModel:
         )
         assert loader is not None
         raw_model = loader(model_path)
-        ram_cache.put(key=cache_key, model=raw_model)
-        return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
+        self._ram_cache.put(key=cache_key, model=raw_model)
+        return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)

invokeai/app/services/model_manager/model_manager_default.py

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,8 @@
 from invokeai.app.services.model_load.model_load_default import ModelLoadService
 from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
 from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
-from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
+from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 

invokeai/backend/model_manager/load/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 
 from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
 from invokeai.backend.model_manager.load.load_default import ModelLoader
-from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
 
 # This registers the subclasses that implement loaders of specific model types

invokeai/backend/model_manager/load/load_base.py

Lines changed: 23 additions & 31 deletions
@@ -5,7 +5,6 @@
 
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from dataclasses import dataclass
 from logging import Logger
 from pathlib import Path
 from typing import Any, Dict, Generator, Optional, Tuple
@@ -18,19 +17,17 @@
     AnyModelConfig,
     SubModelType,
 )
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 
 
-@dataclass
 class LoadedModelWithoutConfig:
-    """
-    Context manager object that mediates transfer from RAM<->VRAM.
+    """Context manager object that mediates transfer from RAM<->VRAM.
 
     This is a context manager object that has two distinct APIs:
 
     1. Older API (deprecated):
-       Use the LoadedModel object directly as a context manager.
-       It will move the model into VRAM (on CUDA devices), and
+       Use the LoadedModel object directly as a context manager. It will move the model into VRAM (on CUDA devices), and
        return the model in a form suitable for passing to torch.
        Example:
        ```
@@ -40,13 +37,9 @@ class LoadedModelWithoutConfig:
        ```
 
     2. Newer API (recommended):
-       Call the LoadedModel's `model_on_device()` method in a
-       context. It returns a tuple consisting of a copy of
-       the model's state dict in CPU RAM followed by a copy
-       of the model in VRAM. The state dict is provided to allow
-       LoRAs and other model patchers to return the model to
-       its unpatched state without expensive copy and restore
-       operations.
+       Call the LoadedModel's `model_on_device()` method in a context. It returns a tuple consisting of a copy of the
+       model's state dict in CPU RAM followed by a copy of the model in VRAM. The state dict is provided to allow LoRAs and
+       other model patchers to return the model to its unpatched state without expensive copy and restore operations.
 
        Example:
        ```
@@ -55,43 +48,42 @@ class LoadedModelWithoutConfig:
            image = vae.decode(latents)[0]
        ```
 
-    The state_dict should be treated as a read-only object and
-    never modified. Also be aware that some loadable models do
-    not have a state_dict, in which case this value will be None.
+    The state_dict should be treated as a read-only object and never modified. Also be aware that some loadable models
+    do not have a state_dict, in which case this value will be None.
     """
 
-    _locker: ModelLockerBase
+    def __init__(self, cache_record: CacheRecord, cache: ModelCache):
+        self._cache_record = cache_record
+        self._cache = cache
 
     def __enter__(self) -> AnyModel:
-        """Context entry."""
-        self._locker.lock()
+        self._cache.lock(self._cache_record.key)
        return self.model
 
     def __exit__(self, *args: Any, **kwargs: Any) -> None:
-        """Context exit."""
-        self._locker.unlock()
+        self._cache.unlock(self._cache_record.key)
 
     @contextmanager
     def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
         """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
-        locked_model = self._locker.lock()
+        self._cache.lock(self._cache_record.key)
        try:
-            state_dict = self._locker.get_state_dict()
-            yield (state_dict, locked_model)
+            yield (self._cache_record.state_dict, self._cache_record.model)
        finally:
-            self._locker.unlock()
+            self._cache.unlock(self._cache_record.key)
 
     @property
     def model(self) -> AnyModel:
         """Return the model without locking it."""
-        return self._locker.model
+        return self._cache_record.model
 
 
-@dataclass
 class LoadedModel(LoadedModelWithoutConfig):
     """Context manager object that mediates transfer from RAM<->VRAM."""
 
-    config: Optional[AnyModelConfig] = None
+    def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
+        super().__init__(cache_record=cache_record, cache=cache)
+        self.config = config
 
 
 # TODO(MM2):
@@ -110,7 +102,7 @@ def __init__(
         self,
         app_config: InvokeAIAppConfig,
         logger: Logger,
-        ram_cache: ModelCacheBase[AnyModel],
+        ram_cache: ModelCache,
     ):
         """Initialize the loader."""
         pass
@@ -138,6 +130,6 @@ def get_size_fs(
 
     @property
     @abstractmethod
-    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+    def ram_cache(self) -> ModelCache:
         """Return the ram cache associated with this loader."""
         pass

invokeai/backend/model_manager/load/load_default.py

Lines changed: 10 additions & 14 deletions
@@ -14,7 +14,8 @@
 )
 from invokeai.backend.model_manager.config import DiffusersConfigBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import TorchDevice
@@ -28,7 +29,7 @@ def __init__(
         self,
         app_config: InvokeAIAppConfig,
         logger: Logger,
-        ram_cache: ModelCacheBase[AnyModel],
+        ram_cache: ModelCache,
     ):
         """Initialize the loader."""
         self._app_config = app_config
@@ -54,22 +55,22 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
             raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
 
         with skip_torch_weight_init():
-            locker = self._load_and_cache(model_config, submodel_type)
-        return LoadedModel(config=model_config, _locker=locker)
+            cache_record = self._load_and_cache(model_config, submodel_type)
+        return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
 
     @property
-    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+    def ram_cache(self) -> ModelCache:
         """Return the ram cache associated with this loader."""
         return self._ram_cache
 
     def _get_model_path(self, config: AnyModelConfig) -> Path:
         model_base = self._app_config.models_path
         return (model_base / config.path).resolve()
 
-    def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
+    def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
         stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
         try:
-            return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
+            return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
         except IndexError:
             pass
 
@@ -78,16 +79,11 @@ def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubMod
         loaded_model = self._load_model(config, submodel_type)
 
         self._ram_cache.put(
-            config.key,
-            submodel_type=submodel_type,
+            get_model_cache_key(config.key, submodel_type),
             model=loaded_model,
         )
 
-        return self._ram_cache.get(
-            key=config.key,
-            submodel_type=submodel_type,
-            stats_name=stats_name,
-        )
+        return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
 
     def get_size_fs(
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
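The diff above calls a new `get_model_cache_key` helper whose implementation is not shown on this page. A plausible sketch is below, assuming it simply folds the optional submodel type into the string key so the cache no longer needs to understand submodels; the exact key format is an assumption, not taken from this PR.

```python
from typing import Optional


def get_model_cache_key(model_key: str, submodel_type: Optional[str] = None) -> str:
    """Hypothetical reconstruction: combine a model's database key with an optional
    submodel type into a single cache key. The real helper lives in
    invokeai/backend/model_manager/load/model_cache/model_cache.py and may differ."""
    if submodel_type is None:
        return model_key
    return f"{model_key}:{submodel_type}"
```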

invokeai/backend/model_manager/load/model_cache/__init__.py

Lines changed: 0 additions & 6 deletions
@@ -1,6 +0,0 @@
-"""Init file for ModelCache."""
-
-from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
-from .model_cache_default import ModelCache # noqa F401
-
-_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]

invokeai/backend/model_manager/load/model_cache/cache_record.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+
+
+@dataclass
+class CacheRecord:
+    """
+    Elements of the cache:
+
+    key: Unique key for each model, same as used in the models database.
+    model: Model in memory.
+    state_dict: A read-only copy of the model's state dict in RAM. It will be
+        used as a template for creating a copy in the VRAM.
+    size: Size of the model
+    loaded: True if the model's state dict is currently in VRAM
+
+    Before a model is executed, the state_dict template is copied into VRAM,
+    and then injected into the model. When the model is finished, the VRAM
+    copy of the state dict is deleted, and the RAM version is reinjected
+    into the model.
+
+    The state_dict should be treated as a read-only attribute. Do not attempt
+    to patch or otherwise modify it. Instead, patch the copy of the state_dict
+    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
+    context manager call `model_on_device()`.
+    """
+
+    key: str
+    model: Any
+    device: torch.device
+    state_dict: Optional[Dict[str, torch.Tensor]]
+    size: int
+    loaded: bool = False
+    _locks: int = 0
+
+    def lock(self) -> None:
+        """Lock this record."""
+        self._locks += 1
+
+    def unlock(self) -> None:
+        """Unlock this record."""
+        self._locks -= 1
+        assert self._locks >= 0
+
+    @property
+    def locked(self) -> bool:
+        """Return true if record is locked."""
+        return self._locks > 0
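As a brief illustration of the reference-count style locking above, a minimal sketch; the constructor arguments are contrived placeholders, since in practice records are created and managed by the `ModelCache` itself.

```python
import torch

from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord

# Contrived example values; a real record is built by ModelCache.put().
record = CacheRecord(
    key="some-model-key:vae",     # placeholder key
    model=torch.nn.Linear(4, 4),  # stand-in for a real model
    device=torch.device("cpu"),
    state_dict=None,              # some loadable models have no state dict
    size=256,                     # placeholder size in bytes
)

record.lock()         # first consumer starts using the model
record.lock()         # a second consumer can hold it concurrently
assert record.locked  # locked while any consumer holds it

record.unlock()
record.unlock()
assert not record.locked  # eligible for eviction once the count drops to zero
```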

invokeai/backend/model_manager/load/model_cache/cache_stats.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+from dataclasses import dataclass, field
+from typing import Dict
+
+
+@dataclass
+class CacheStats(object):
+    """Collect statistics on cache performance."""
+
+    hits: int = 0  # cache hits
+    misses: int = 0  # cache misses
+    high_watermark: int = 0  # amount of cache used
+    in_cache: int = 0  # number of models in cache
+    cleared: int = 0  # number of models cleared to make space
+    cache_size: int = 0  # total size of cache
+    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
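For context, a minimal sketch of how counters like these are typically maintained on a cache lookup; the helper name and bookkeeping shown here are illustrative only and are not part of this diff.

```python
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats


def note_lookup(stats: CacheStats, hit: bool, model_key: str, model_size: int) -> None:
    """Illustrative bookkeeping: bump hit/miss counters and remember per-model sizes."""
    if hit:
        stats.hits += 1
        stats.loaded_model_sizes[model_key] = model_size
    else:
        stats.misses += 1


stats = CacheStats()
note_lookup(stats, hit=False, model_key="unet", model_size=0)
note_lookup(stats, hit=True, model_key="unet", model_size=3_400_000_000)
print(stats.hits, stats.misses)  # 1 1
```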
