Skip to content

Commit 4b63bfc

Browse files
committed
Remove unnecessary hasattr checks for scaled_dot_product_attention. We pin the torch version, so there should be no concern that this function does not exist.
1 parent bbb9939 commit 4b63bfc

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

invokeai/backend/stable_diffusion/diffusers_pipeline.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,29 +187,24 @@ def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
187187
self.disable_attention_slicing()
188188
return
189189
elif config.attention_type == "torch-sdp":
190-
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
191-
# diffusers enables sdp automatically
192-
return
193-
else:
194-
raise Exception("torch-sdp attention slicing not available")
190+
# torch-sdp is the default in diffusers.
191+
return
195192

196193
# See https://github.com/invoke-ai/InvokeAI/issues/7049 for context.
197194
# Bumping torch from 2.2.2 to 2.4.1 caused the sliced attention implementation to produce incorrect results.
198195
# For now, if a user is on an MPS device and has not explicitly set the attention_type, then we select the
199196
# non-sliced torch-sdp implementation. This keeps things working on MPS at the cost of increased peak memory
200197
# utilization.
201198
if torch.backends.mps.is_available():
202-
assert hasattr(torch.nn.functional, "scaled_dot_product_attention")
203199
return
204200

205-
# the remainder if this code is called when attention_type=='auto'
201+
# The remainder of this code is called when attention_type=='auto'.
206202
if self.unet.device.type == "cuda":
207203
if is_xformers_available():
208204
self.enable_xformers_memory_efficient_attention()
209205
return
210-
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
211-
# diffusers enables sdp automatically
212-
return
206+
# torch-sdp is the default in diffusers.
207+
return
213208

214209
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
215210
mem_free = psutil.virtual_memory().free

0 commit comments

Comments
 (0)