Skip to content

Commit 2678b4c

Browse files
RyanJDick authored and psychedelicious committed
Remove unnecessary hasattr checks for scaled_dot_product_attention. We pin the torch version, so there should be no concern that this function does not exist.
1 parent 889e572 commit 2678b4c

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

invokeai/backend/stable_diffusion/diffusers_pipeline.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,29 +198,24 @@ def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
198198
self.disable_attention_slicing()
199199
return
200200
elif config.attention_type == "torch-sdp":
201-
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
202-
# diffusers enables sdp automatically
203-
return
204-
else:
205-
raise Exception("torch-sdp attention slicing not available")
201+
# torch-sdp is the default in diffusers.
202+
return
206203

207204
# See https://github.com/invoke-ai/InvokeAI/issues/7049 for context.
208205
# Bumping torch from 2.2.2 to 2.4.1 caused the sliced attention implementation to produce incorrect results.
209206
# For now, if a user is on an MPS device and has not explicitly set the attention_type, then we select the
210207
# non-sliced torch-sdp implementation. This keeps things working on MPS at the cost of increased peak memory
211208
# utilization.
212209
if torch.backends.mps.is_available():
213-
assert hasattr(torch.nn.functional, "scaled_dot_product_attention")
214210
return
215211

216-
# the remainder if this code is called when attention_type=='auto'
212+
# The remainder if this code is called when attention_type=='auto'.
217213
if self.unet.device.type == "cuda":
218214
if is_xformers_available() and prefer_xformers:
219215
self.enable_xformers_memory_efficient_attention()
220216
return
221-
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
222-
# diffusers enables sdp automatically
223-
return
217+
# torch-sdp is the default in diffusers.
218+
return
224219

225220
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
226221
mem_free = psutil.virtual_memory().free

0 commit comments

Comments (0)