
Commit 589a795

Author: Lincoln Stein (committed)

fixup unit tests and remove debugging statements

1 parent e26360f · commit 589a795

11 files changed: +61 −186 lines changed

invokeai/app/api/dependencies.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 
 import torch
 
-import invokeai.backend.util.devices  # horrible hack
 from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
 from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db

invokeai/app/invocations/compel.py

Lines changed: 2 additions & 0 deletions
@@ -99,6 +99,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             textual_inversion_manager=ti_manager,
             dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
+            device=TorchDevice.choose_torch_device(),
         )
 
         conjunction = Compel.parse_prompt_string(self.prompt)
@@ -113,6 +114,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
 
         conditioning_name = context.conditioning.save(conditioning_data)
+
         return ConditioningOutput(
             conditioning=ConditioningField(
                 conditioning_name=conditioning_name,
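The new `device=` argument pairs with the existing `dtype_for_device_getter`. A minimal sketch of how the `TorchDevice` helper referenced above is typically queried, assuming its classmethods behave as this diff suggests (run inside an InvokeAI environment; not part of this commit):

import torch

from invokeai.backend.util.devices import TorchDevice

# Pick the execution device (CUDA, MPS, or CPU) and a dtype suited to it.
device: torch.device = TorchDevice.choose_torch_device()
dtype: torch.dtype = TorchDevice.choose_torch_dtype(device)
print(device, dtype)  # e.g. "cuda:0 torch.float16"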

invokeai/app/services/invocation_stats/invocation_stats_default.py

Lines changed: 3 additions & 3 deletions
@@ -74,9 +74,9 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str
         )
         self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
 
-    def reset_stats(self):
-        self._stats = {}
-        self._cache_stats = {}
+    def reset_stats(self, graph_execution_state_id: str):
+        self._stats.pop(graph_execution_state_id)
+        self._cache_stats.pop(graph_execution_state_id)
 
     def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
         graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
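The behavioural change: instead of wiping every session's statistics, `reset_stats` now drops only the entries for the finished graph. A self-contained sketch of that semantics, assuming the stats stores are plain dicts keyed by `graph_execution_state_id` (the names below are illustrative stand-ins, not the real service classes):

# Illustrative stand-ins for the service's internal per-graph stores.
_stats: dict[str, str] = {"ges-1": "node stats for ges-1", "ges-2": "node stats for ges-2"}
_cache_stats: dict[str, str] = {"ges-1": "cache stats", "ges-2": "cache stats"}

def reset_stats(graph_execution_state_id: str) -> None:
    # Remove only the finished session's entries; concurrent sessions keep theirs.
    # dict.pop with a single argument raises KeyError for an unknown id.
    _stats.pop(graph_execution_state_id)
    _cache_stats.pop(graph_execution_state_id)

reset_stats("ges-1")
assert "ges-1" not in _stats and "ges-2" in _stats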

invokeai/app/services/model_manager/model_manager_default.py

Lines changed: 0 additions & 2 deletions
@@ -76,8 +76,6 @@ def build_model_manager(
 
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
-            max_vram_cache_size=app_config.vram,
-            lazy_offloading=app_config.lazy_offload,
             logger=logger,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)

invokeai/app/services/session_processor/session_processor_default.py

Lines changed: 7 additions & 6 deletions
@@ -1,7 +1,7 @@
 import traceback
 from contextlib import suppress
 from queue import Queue
-from threading import BoundedSemaphore, Thread, Lock
+from threading import BoundedSemaphore, Lock, Thread
 from threading import Event as ThreadEvent
 from typing import Optional, Set
 
@@ -61,7 +61,9 @@ def __init__(
         self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
         self._process_lock = Lock()
 
-    def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None) -> None:
+    def start(
+        self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
+    ) -> None:
         self._services = services
         self._cancel_event = cancel_event
         self._profiler = profiler
@@ -214,7 +216,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
            # we don't care about that - suppress the error.
            with suppress(GESStatsNotFoundError):
                self._services.performance_statistics.log_stats(queue_item.session.id)
-               self._services.performance_statistics.reset_stats()
+               self._services.performance_statistics.reset_stats(queue_item.session.id)
 
            for callback in self._on_after_run_session_callbacks:
                callback(queue_item=queue_item)
@@ -384,7 +386,6 @@ def start(self, invoker: Invoker) -> None:
            )
            worker.start()
 
-
    def stop(self, *args, **kwargs) -> None:
        self._stop_event.set()
 
@@ -465,7 +466,7 @@ def _process(
                    # Run the graph
                    # self.session_runner.run(queue_item=self._queue_item)
 
-               except Exception:
+               except Exception:
                    # Wait for next polling interval or event to try again
                    poll_now_event.wait(self._polling_interval)
                    continue
@@ -494,7 +495,7 @@ def _process_next_session(self) -> None:
                with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
                    # Run the session on the reserved GPU
                    self.session_runner.run(queue_item=queue_item)
-           except Exception as e:
+           except Exception:
               continue
           finally:
               self._active_queue_items.remove(queue_item)
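For context on the unchanged `reserve_execution_device()` line above: each worker checks out a GPU for the whole session via a context manager. A rough, self-contained sketch of that reservation pattern (this is not the actual `ModelCache` implementation; it only illustrates the check-out/return idea with a queue of device names):

from contextlib import contextmanager
from queue import Queue
from typing import Iterator

# Illustrative pool of execution devices; the real cache tracks torch.device objects.
_free_devices: "Queue[str]" = Queue()
for name in ("cuda:0", "cuda:1"):
    _free_devices.put(name)

@contextmanager
def reserve_execution_device() -> Iterator[str]:
    device = _free_devices.get()  # blocks until a device is free
    try:
        yield device  # the caller runs its whole session on this device
    finally:
        _free_devices.put(device)  # return the device for the next session

with reserve_execution_device() as device:
    print(f"running session on {device}")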

invokeai/app/services/session_queue/session_queue_common.py

Lines changed: 1 addition & 0 deletions
@@ -239,6 +239,7 @@ def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
     def __hash__(self) -> int:
         return self.item_id
 
+
 class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
     pass
 

invokeai/app/services/shared/invocation_context.py

Lines changed: 0 additions & 1 deletion
@@ -325,7 +325,6 @@ def load(self, name: str) -> ConditioningFieldData:
         Returns:
             The loaded conditioning data.
         """
-
         return self._services.conditioning.load(name)
 
 

invokeai/backend/model_manager/load/model_cache/model_cache_base.py

Lines changed: 1 addition & 34 deletions
@@ -43,26 +43,9 @@ def model(self) -> AnyModel:
 
 @dataclass
 class CacheRecord(Generic[T]):
-    """
-    Elements of the cache:
-
-    key: Unique key for each model, same as used in the models database.
-    model: Model in memory.
-    state_dict: A read-only copy of the model's state dict in RAM. It will be
-        used as a template for creating a copy in the VRAM.
-    size: Size of the model
-    loaded: True if the model's state dict is currently in VRAM
-
-    Before a model is executed, the state_dict template is copied into VRAM,
-    and then injected into the model. When the model is finished, the VRAM
-    copy of the state dict is deleted, and the RAM version is reinjected
-    into the model.
-    """
+    """Elements of the cache."""
 
     key: str
-    model: T
-    device: torch.device
-    state_dict: Optional[Dict[str, torch.Tensor]]
     size: int
     model: T
     loaded: bool = False
@@ -130,28 +113,12 @@ def get_execution_device(self) -> torch.device:
         """
         pass
 
-    @property
-    @abstractmethod
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
-        pass
-
     @property
     @abstractmethod
     def max_cache_size(self) -> float:
         """Return true if the cache is configured to lazily offload models in VRAM."""
         pass
 
-    @abstractmethod
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Offload from VRAM any models not actively in use."""
-        pass
-
-    @abstractmethod
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device."""
-        pass
-
     @property
     @abstractmethod
     def stats(self) -> Optional[CacheStats]:
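After this hunk, `CacheRecord` no longer carries per-record `device` or `state_dict` bookkeeping. A hedged, self-contained sketch of the trimmed-down record, reconstructed from the fields that survive in the diff (the `Generic[T]` mirrors the original class; the comments paraphrase the deleted docstring and are not authoritative):

from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")

@dataclass
class CacheRecordSketch(Generic[T]):
    """Elements of the cache (sketch of the post-change CacheRecord)."""

    key: str              # unique key for the model, same as in the models database
    size: int             # size of the model in memory
    model: T              # the model object itself, held in RAM
    loaded: bool = False  # flag maintained by the cache (previously: state dict is in VRAM)

record = CacheRecordSketch(key="abc123", size=2_000_000_000, model=object())
print(record.loaded)  # False until the cache marks the model as loaded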
