Enable Flash Attention for SD3 MMDiT #2014
Merged
This PR utilizes `ops.dot_product_attention` to accelerate inference in SD3.

I noticed that `ops.dot_product_attention` performed slower than the vanilla implementation on the TensorFlow backend, so this optimization path is skipped there (vanilla: 10.55s vs. `ops.dot_product_attention`: 14.33s).
EDITED: The JAX backend now runs faster than `diffusers` out of the box (`diffusers.StableDiffusion3Pipeline`: 6.15s).

The benchmark script (KerasHub):
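The original script is not reproduced here; a minimal sketch of a KerasHub timing run, assuming the `StableDiffusion3TextToImage` task API and the `stable_diffusion_3_medium` preset:

```python
# Illustrative sketch, not the exact script used for the numbers above.
import time

from keras_hub.models import StableDiffusion3TextToImage

model = StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", dtype="float16"
)
prompt = "a photograph of an astronaut riding a horse"

# Warm-up run so compilation/tracing is excluded from the timing.
model.generate(prompt, num_steps=28)

start = time.time()
model.generate(prompt, num_steps=28)
print(f"KerasHub SD3: {time.time() - start:.2f}s")
```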
The benchmark script (`diffusers`):
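Similarly, a minimal sketch of the `diffusers` timing run (the model id and generation settings are assumptions):

```python
# Illustrative sketch, not the exact script used for the numbers above.
import time

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
).to("cuda")
prompt = "a photograph of an astronaut riding a horse"

# Warm-up run so one-time setup is excluded from the timing.
pipe(prompt, num_inference_steps=28)

start = time.time()
pipe(prompt, num_inference_steps=28)
print(f"diffusers SD3: {time.time() - start:.2f}s")
```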