Add multi-gpu support #5997

Draft: wants to merge 41 commits into main from lstein/feat/multi-gpu
Changes from all commits (41 commits)
6b991a5  add draft multi-gpu support (Feb 19, 2024)
24d7328  Merge branch 'main' into lstein/feat/multi-gpu (Mar 22, 2024)
eaa2c68  remove vram_cache and don't move VRAM models back into CPU (Mar 31, 2024)
bd9b00a  fix(nodes): 100% cpu usage when processor paused (psychedelicious, Mar 31, 2024)
a1dcab9  remove references to vram_cache in tests (Mar 31, 2024)
32d3e4d  feat(nodes): simplify processor loop with an early continue (psychedelicious, Mar 31, 2024)
9336a07  add locking around thread-critical sections (Mar 31, 2024)
83356ec  fix merge conflicts (Mar 31, 2024)
cef51ad  Merge branch 'psyche/fix/nodes/processor-cpu-usage' into lstein/feat/… (Mar 31, 2024)
9df0980  parallel processing working on single-GPU, not tested on multi (Apr 1, 2024)
eca29c4  added notes (Apr 1, 2024)
3d69372  implement session-level reservation of gpus (Apr 1, 2024)
9adb15f  working but filled with debug statements (Apr 1, 2024)
7dd93cb  fix merge issues; likely nonfunctional (Apr 16, 2024)
f7436f3  fixup config_default; patch TorchDevice to work dynamically (Apr 16, 2024)
a84f305  revert object_serializer_forward_cache.py (Apr 16, 2024)
bd83390  add tid to cache name to avoid non-safe uuid4 on windows (Apr 16, 2024)
fb9b7fb  make object_serializer._new_name() thread-safe; add max_threads config (Apr 16, 2024)
371f5bc  simplify logic for retrieving execution devices (Apr 16, 2024)
77130f1  Merge branch 'main' into lstein/feat/multi-gpu (lstein, Apr 16, 2024)
99558de  device selection calls go through TorchDevice (Apr 16, 2024)
89f8326  Merge branch 'lstein/feat/multi-gpu' of github.com:invoke-ai/InvokeAI… (Apr 16, 2024)
eaadc55  make pause/resume work in multithreaded environment (Apr 16, 2024)
763a2e2  added more unit tests (Apr 16, 2024)
d04c880  fix ValueError on model manager install (Apr 16, 2024)
edac01d  reverse stupid hack (Apr 16, 2024)
84f5cbd  make choose_torch_dtype() usable outside an invocation context (Apr 16, 2024)
c3d1252  revert to old system for doing RAM <-> VRAM transfers; new way leaks … (Apr 17, 2024)
1c0067f  Merge branch 'main' into lstein/feat/multi-gpu (lstein, Apr 30, 2024)
e57809e  Merge branch 'main' into lstein/feat/multi-gpu (lstein, May 3, 2024)
debef24  Merge branch 'main' into lstein/feat/multi-gpu (lstein, May 6, 2024)
e26360f  merged multi-gpu support into new session_processor architecture (Jun 2, 2024)
589a795  fixup unit tests and remove debugging statements (Jun 2, 2024)
7088d56  add script to sync models db with models.yaml (Jun 16, 2024)
0df018b  resolve merge conflicts (Jun 23, 2024)
6932f27  fixup code broken by merge with main (Jun 23, 2024)
2219e36  copy model from a meta device template (Jun 24, 2024)
9b7b182  remove dangling attributes in ModelCache class (Jun 24, 2024)
5d6a77d  fixup ip adapter handling (Jun 24, 2024)
02957be  fix compel conditioning object caching issue by applying deepcopy() b… (Jul 18, 2024)
9dcace7  ruff fixes and restore default map location of object serializer load (Jul 18, 2024)
Files changed:
2 changes: 1 addition & 1 deletion docs/contributing/MODEL_MANAGER.md
@@ -1328,7 +1328,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist

config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
max_cache_size=config.ram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
4 changes: 3 additions & 1 deletion invokeai/app/invocations/compel.py
@@ -103,6 +103,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
)

conjunction = Compel.parse_prompt_string(self.prompt)
@@ -117,6 +118,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])

conditioning_name = context.conditioning.save(conditioning_data)

return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
@@ -203,6 +205,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
)

conjunction = Compel.parse_prompt_string(prompt)
@@ -313,7 +316,6 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:
)
]
)

conditioning_name = context.conditioning.save(conditioning_data)

return ConditioningOutput(
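The compel changes above pass an explicit device into the Compel constructor rather than relying on a process-wide default. A minimal sketch of how the TorchDevice helpers resolve that per call (the import path and choose_torch_device call appear elsewhere in this PR; passing the device into choose_torch_dtype is an assumption about its signature):

```python
# Hedged sketch: each worker thread asks for its execution device at
# invocation time instead of using one global default.
from invokeai.backend.util.devices import TorchDevice

device = TorchDevice.choose_torch_device()      # e.g. cuda:0, cuda:1, mps, or cpu
dtype = TorchDevice.choose_torch_dtype(device)  # float16/bfloat16/float32 for that device
```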
5 changes: 3 additions & 2 deletions invokeai/app/invocations/denoise_latents.py
@@ -1,4 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import copy
import inspect
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -193,9 +194,8 @@ def _get_text_embeddings_and_masks(
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name)
cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))

mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
@@ -226,6 +226,7 @@ def _preprocess_regional_prompt_mask(
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
assert isinstance(resized_mask, torch.Tensor)
return resized_mask

def _concat_regional_text_embeddings(
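The deepcopy added in _get_text_embeddings_and_masks guards against mutating a conditioning object that may be shared between concurrently running sessions. A standalone sketch of the failure mode it avoids; the CondInfo class is a stand-in assumption, not InvokeAI code:

```python
import copy

import torch


class CondInfo:
    """Stand-in for a cached conditioning object whose .to() mutates its own
    tensors, as the object returned by context.conditioning.load() effectively
    does. Illustrative only."""

    def __init__(self, embeds: torch.Tensor) -> None:
        self.embeds = embeds

    def to(self, device: str, dtype: torch.dtype) -> "CondInfo":
        self.embeds = self.embeds.to(device=device, dtype=dtype)
        return self


shared = CondInfo(torch.zeros(1, 77, 768))               # one object shared via the cache
local = copy.deepcopy(shared).to("cpu", torch.float16)   # private copy; `shared` is untouched
```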
42 changes: 34 additions & 8 deletions invokeai/app/services/config/config_default.py
@@ -26,13 +26,13 @@
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.1"
CONFIG_SCHEMA_VERSION = "4.0.2"


def get_default_ram_cache_size() -> float:
@@ -105,14 +105,16 @@ class InvokeAIAppConfig(BaseSettings):
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
devices: List of execution devices; will override default device selected.
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
@@ -178,6 +180,7 @@ class InvokeAIAppConfig(BaseSettings):

# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")

# GENERATION
@@ -187,6 +190,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")

# NODES
@@ -376,9 +380,6 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
@@ -426,6 +427,27 @@ def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig
return config


def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.1 config dictionary to a current config object.

A few new multi-GPU options were added in 4.0.2, and this simply
updates the schema label.

Args:
config_dict: A dictionary of settings from a v4.0.1 config file.

Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for k, _ in config_dict.items():
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config


# TO DO: replace this with a formal registration and migration system
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.

@@ -457,6 +479,10 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)

elif loaded_config_dict["schema_version"] == "4.0.1":
loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)

# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
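For reference, a hedged sketch of how the two new settings introduced above might be supplied programmatically. The field names `devices` and `max_threads` come from the diff; the values, and the idea of one worker thread per GPU, are illustrative assumptions rather than defaults:

```python
# Illustrative only: values are assumptions, not shipped defaults.
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig(
    devices=["cuda:0", "cuda:1"],  # overrides the single `device` setting
    max_threads=2,                 # e.g. one session-execution thread per GPU
)
```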
4 changes: 2 additions & 2 deletions invokeai/app/services/invocation_services.py
@@ -53,11 +53,11 @@ def __init__(
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
performance_statistics: "InvocationStatsServiceBase",
urls: "UrlServiceBase",
workflow_records: "WorkflowRecordsStorageBase",
tensors: "ObjectSerializerBase[torch.Tensor]",
@@ -77,11 +77,11 @@ def __init__(
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
self.performance_statistics = performance_statistics
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache
self.names = names
self.performance_statistics = performance_statistics
self.urls = urls
self.workflow_records = workflow_records
self.tensors = tensors
@@ -74,9 +74,9 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)

def reset_stats(self):
self._stats = {}
self._cache_stats = {}
def reset_stats(self, graph_execution_state_id: str):
self._stats.pop(graph_execution_state_id)
self._cache_stats.pop(graph_execution_state_id)

def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
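reset_stats now clears a single graph's entries rather than wiping the whole dictionary, which matters once several session threads write statistics concurrently. A hedged usage sketch; the service and logger parameter names are assumptions:

```python
def report_and_clear(stats_service, logger, graph_execution_state_id: str) -> None:
    # Summarize one finished graph, then drop only its entries so statistics
    # still being accumulated by other worker threads are left intact.
    summary = stats_service.get_stats(graph_execution_state_id)
    logger.info(summary)
    stats_service.reset_stats(graph_execution_state_id)
```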
10 changes: 8 additions & 2 deletions invokeai/app/services/model_install/model_install_default.py
@@ -284,9 +284,14 @@ def prune_jobs(self) -> None:
unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
self._install_jobs = unfinished_jobs

def _migrate_yaml(self) -> None:
def _migrate_yaml(self, rename_yaml: Optional[bool] = True, overwrite_db: Optional[bool] = False) -> None:
db_models = self.record_store.all_models()

if overwrite_db:
for model in db_models:
self.record_store.del_model(model.key)
db_models = self.record_store.all_models()

legacy_models_yaml_path = (
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
)
@@ -336,7 +341,8 @@ def _migrate_yaml(self) -> None:
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")

# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
if rename_yaml:
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))

# Unset the path - we are done with it either way
self._app_config.legacy_models_yaml_path = None
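The new keyword arguments make _migrate_yaml() re-runnable, which is what the "add script to sync models db with models.yaml" commit appears to rely on. A hedged sketch of that usage; the `installer` argument is assumed to be a started model install service:

```python
def resync_models_db(installer) -> None:
    # Illustrative: wipe the DB records and re-import models.yaml, leaving the
    # YAML file unrenamed so the sync can be run again later.
    installer._migrate_yaml(rename_yaml=False, overwrite_db=True)
```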
5 changes: 5 additions & 0 deletions invokeai/app/services/model_load/model_load_base.py
@@ -33,6 +33,11 @@ def ram_cache(self) -> ModelCacheBase[AnyModel]:
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""

@property
@abstractmethod
def gpu_count(self) -> int:
"""Return the number of GPUs we are configured to use."""

@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
6 changes: 6 additions & 0 deletions invokeai/app/services/model_load/model_load_default.py
@@ -46,13 +46,19 @@ def __init__(
self._registry = registry

def start(self, invoker: Invoker) -> None:
"""Start the service."""
self._invoker = invoker

@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
return self._ram_cache

@property
def gpu_count(self) -> int:
"""Return the number of GPUs available for our uses."""
return len(self._ram_cache.execution_devices)

@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
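The new gpu_count property pairs with the max_threads setting documented in config_default.py ("Autocalculated from number of GPUs if not set"). A hedged sketch of that fallback; the exact calculation used by the session processor is an assumption:

```python
def choose_worker_count(config, model_load_service) -> int:
    # Honor an explicit max_threads; otherwise run one session-execution
    # thread per configured GPU, never fewer than one.
    if config.max_threads is not None:
        return config.max_threads
    return max(1, model_load_service.gpu_count)
```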
3 changes: 2 additions & 1 deletion invokeai/app/services/model_manager/model_manager_base.py
@@ -1,6 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team

from abc import ABC, abstractmethod
from typing import Optional, Set

import torch
from typing_extensions import Self
@@ -31,7 +32,7 @@ def build_model_manager(
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device,
execution_devices: Optional[Set[torch.device]] = None,
) -> Self:
"""
Construct the model manager service instance.
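build_model_manager now accepts an optional set of execution devices instead of a single device. A sketch of the kind of value the new parameter takes; enumerating every visible CUDA device is an illustrative choice, not the PR's default behaviour:

```python
import torch

# When execution_devices is left as None, device discovery is presumably
# delegated to the model cache itself.
execution_devices = {torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())}
```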
@@ -1,14 +1,10 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""

from typing import Optional

import torch
from typing_extensions import Self

from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

from ..config import InvokeAIAppConfig
@@ -69,7 +65,6 @@ def build_model_manager(
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
@@ -82,9 +77,7 @@ def build_model_manager(
ram_cache = ModelCache(
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
@@ -1,5 +1,6 @@
import shutil
import tempfile
import threading
import typing
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar
@@ -9,6 +10,7 @@
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.devices import TorchDevice

if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
@@ -70,7 +72,10 @@ def _get_path(self, name: str) -> Path:
return self._output_dir / name

def _new_name(self) -> str:
return f"{self._obj_class_name}_{uuid_string()}"
tid = threading.current_thread().ident
# Add tid to the object name because uuid4 not thread-safe on windows
# See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
return f"{self._obj_class_name}_{tid}-{uuid_string()}"

def _tempdir_cleanup(self) -> None:
"""Calls `cleanup` on the temporary directory, if it exists."""
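A standalone restatement of the naming scheme above, useful for seeing why the thread ident removes the collision risk even if uuid4 repeats across threads on Windows. This mirrors _new_name() but is illustrative, not the PR code:

```python
import threading
import uuid


def thread_safe_name(prefix: str) -> str:
    # Embed the calling thread's ident so two workers can never race to the
    # same on-disk object name, even if uuid4 were to produce a duplicate.
    tid = threading.current_thread().ident
    return f"{prefix}_{tid}-{uuid.uuid4().hex}"
```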