feat: Add support for FP8 MLA on Hopper and Blackwell. #3190
Conversation
Use BF16 for context MLA. mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together. Allow mSM == 90 when mFP8GenerationMLA == true.
For FMHA, dataTypeKv should be FP8. For FP8 MLA generation, the output is still in BF16.
Refine debug info for FMHA kernel metadata. Use inputType, outputType, and SM together to hash the kernel list.
Add the FP8 MLA generation FMHA kernel. Special workaround (WAR) of NUM_COMPUTE_GROUPS for the MLA generation kernel.
Separate the implementation of fused_multihead_attention_v2.h into a .cpp file and print some debug info if checkIfKernelExist fails. Refine debug info in fused_multihead_attention_v2.cpp.
Correct the FP8 MLA metadata.
New kernel provided by Yuxin, which outputs BF16. The smem size was not set correctly, which led to illegal memory access.
Yuxin fixed the error in the FMHA MLA kernel: previously the BF16 output wasn't written correctly; some parts were written repeatedly while others were left untouched. There are two bmm1 scales that must be set correctly.
New kernel generated by Yuxin.
Modifications to common/attentionOp for FP8 MLA on Hopper using FMHA. Not necessary: if mFP8GenerationMLA is set, is_fp8_out is false, so mFP8ContextFMHA is false. Skip a check in fmhaDispatcher.
Modifications in fmhaRunner: debug dump; if (!isFP8GenerationMLA) skips a lot of flag setting; TMA descriptor modification for qo (by Yuxin).
Clean up debug output and the o TMA descriptor modifications.
Signed-off-by: Bo Li <[email protected]>
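The flag interaction and kernel lookup described in the commit message above can be illustrated with a small sketch. This is a hedged illustration only: MlaConfig, kernelKey, and the field names are hypothetical stand-ins, not the actual attentionOp/fmhaRunner interfaces. It shows two points from the message: mFP8GenerationMLA and mFP8ContextFMHA are mutually exclusive, and the FMHA kernel list is keyed by input type, output type, and SM together.

```cpp
// Hypothetical sketch; names are illustrative, not the real TensorRT-LLM types.
#include <cassert>
#include <cstdint>

enum class DataType : uint32_t { kBF16 = 0, kFP8 = 1 };

struct MlaConfig
{
    bool fp8GenerationMla = false; // FP8 FMHA for the generation phase
    bool fp8ContextFmha = false;   // FP8 FMHA for the context phase
    int sm = 90;                   // 90 = Hopper, 100 = Blackwell
};

// Combine (inputType, outputType, SM) into a single key for the kernel list,
// mirroring the "use inputType, outputType, SM together to hash" change.
uint64_t kernelKey(DataType inputType, DataType outputType, int sm)
{
    return (static_cast<uint64_t>(inputType) << 40) | (static_cast<uint64_t>(outputType) << 32)
        | static_cast<uint64_t>(sm);
}

int main()
{
    MlaConfig cfg;
    cfg.fp8GenerationMla = true; // FP8 KV cache + FP8 generation MLA
    cfg.fp8ContextFmha = false;  // context MLA stays in BF16

    // The two FP8 paths must not be enabled at the same time.
    assert(!(cfg.fp8GenerationMla && cfg.fp8ContextFmha));

    // FP8 generation MLA on SM 90: FP8 KV input, BF16 output.
    uint64_t key = kernelKey(DataType::kFP8, DataType::kBF16, cfg.sm);
    (void) key; // would index the FMHA kernel table in the real runner
    return 0;
}
```

Under this scheme the FP8 generation MLA path on SM 90 resolves to an FP8-in/BF16-out kernel entry, matching the note above that the generation output stays in BF16.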
/bot run
PR_Github #869 [ run ] triggered by Bot
PR_Github #869 [ run ] completed with state
@@ -531,7 +531,9 @@ def forward_context(
         # Concat q(including q_pe), k + k_pe, v together as input_qkv
         input_qkv = torch.cat([q, k, v], dim=-1)

-        out_scale = getattr(self.o_proj, "inv_input_scale", None)
+        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
+        out_scale = None  # Currently we use BF16 MHA for context phase
If out_scale is not None, attentionOp will assume that the attention output type is FP8. Currently we want to keep context MLA, as well as the output of generation MLA, in BF16.
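In other words, the presence of an output scale is what tells attentionOp to emit FP8. A minimal sketch of that convention, assuming a hypothetical resolveAttentionOutputType helper (not the real attentionOp API):

```cpp
// Hypothetical sketch: attentionOp is described as treating a non-null output
// scale as "quantize the attention output to FP8"; passing no scale keeps BF16.
#include <cstdio>

enum class DataType { kBF16, kFP8 };

DataType resolveAttentionOutputType(float const* outScale)
{
    // A provided scale implies FP8 output; a null scale implies BF16 output.
    return outScale != nullptr ? DataType::kFP8 : DataType::kBF16;
}

int main()
{
    float scale = 0.5f;
    std::printf("with out_scale:    %s\n",
        resolveAttentionOutputType(&scale) == DataType::kFP8 ? "FP8" : "BF16");
    std::printf("without out_scale: %s\n",
        resolveAttentionOutputType(nullptr) == DataType::kFP8 ? "FP8" : "BF16");
    return 0;
}
```

Passing out_scale = None on the Python side therefore keeps the context MLA path, and the generation MLA output, in BF16.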
Signed-off-by: Dylan Chen <[email protected]>
Signed-off-by: Bo Li <[email protected]>
/bot run
PR_Github #889 [ run ] triggered by Bot
PR_Github #889 [ run ] completed with state
Signed-off-by: Bo Li <[email protected]>
/bot run
PR_Github #912 [ run ] triggered by Bot
PR_Github #912 [ run ] completed with state
Signed-off-by: Bo Li <[email protected]>
/bot run
PR_Github #956 [ run ] triggered by Bot
PR_Github #956 [ run ] completed with state
Signed-off-by: Bo Li <[email protected]>
/bot run
PR_Github #957 [ run ] triggered by Bot
PR_Github #957 [ run ] completed with state
This reverts commit f0c859d. Signed-off-by: Bo Li <[email protected]>
Signed-off-by: Bo Li <[email protected]>
/bot run
PR_Github #1133 [ run ] triggered by Bot
PR_Github #1133 [ run ] completed with state
Signed-off-by: Bo Li <[email protected]>
/bot run
PR_Github #1230 [ run ] triggered by Bot
PR_Github #1230 [ run ] completed with state
/bot run
PR_Github #1248 [ run ] triggered by Bot
PR_Github #1248 [ run ] completed with state
/bot reuse-pipeline
PR_Github #1281 [ reuse-pipeline ] triggered by Bot
PR_Github #1281 [ reuse-pipeline ] completed with state
* fp8 kv + bf16 ctx MLA + fp8 gen MLA
  Use BF16 for context MLA. mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together. Allow mSM == 90 when mFP8GenerationMLA == true. For FMHA, dataTypeKv should be FP8. For FP8 MLA generation, the output is still in BF16. Refine debug info for FMHA kernel metadata. Use inputType, outputType, and SM together to hash the kernel list. Add the FP8 MLA generation FMHA kernel. Special workaround (WAR) of NUM_COMPUTE_GROUPS for the MLA generation kernel. Separate the implementation of fused_multihead_attention_v2.h into a .cpp file and print some debug info if checkIfKernelExist fails. Refine debug info in fused_multihead_attention_v2.cpp. Correct the FP8 MLA metadata. New kernel provided by Yuxin, which outputs BF16. The smem size was not set correctly, which led to illegal memory access. Yuxin fixed the error in the FMHA MLA kernel: previously the BF16 output wasn't written correctly; some parts were written repeatedly while others were left untouched. There are two bmm1 scales that must be set correctly. New kernel generated by Yuxin. Modifications to common/attentionOp for FP8 MLA on Hopper using FMHA. Not necessary: if mFP8GenerationMLA is set, is_fp8_out is false, so mFP8ContextFMHA is false. Skip a check in fmhaDispatcher. Modifications in fmhaRunner: debug dump; if (!isFP8GenerationMLA) skips a lot of flag setting; TMA descriptor modification for qo (by Yuxin). Clean up debug output and the o TMA descriptor modifications.
  Signed-off-by: Bo Li <[email protected]>
* Resolve conflicts. Signed-off-by: Bo Li <[email protected]>
* Apply the patch of FP8 FlashMLA and resolve conflicts. Signed-off-by: Bo Li <[email protected]>
* Fix compilation error. Signed-off-by: Bo Li <[email protected]>
* Fix compile error. Signed-off-by: Bo Li <[email protected]>
* pick blackwell support Signed-off-by: Dylan Chen <[email protected]>
* Add copyright notice to fused_multihead_attention_v2.cpp. Signed-off-by: Bo Li <[email protected]>
* Add license. Signed-off-by: Bo Li <[email protected]>
* Add missing license. Signed-off-by: Bo Li <[email protected]>
* Exclude building flashMLA kernels under sm90. Signed-off-by: Bo Li <[email protected]>
* Revert "Exclude building flashMLA kernels under sm90." This reverts commit f0c859d. Signed-off-by: Bo Li <[email protected]>
* Use macro to skip compiling FlashMLA for non sm90 targets. Signed-off-by: Bo Li <[email protected]>
---------
Signed-off-by: Bo Li <[email protected]>
Signed-off-by: Dylan Chen <[email protected]>
Co-authored-by: Dylan Chen <[email protected]>
Co-authored-by: Dylan Chen <[email protected]>
Co-authored-by: QI JUN <[email protected]>
Signed-off-by: sarattha <[email protected]>
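The last commit above ("Use macro to skip compiling FlashMLA for non sm90 targets") replaces the earlier, reverted approach of excluding the kernels from the build. A rough sketch of that compile-time guard, with a hypothetical EXCLUDE_SM_90-style macro and stub function (the real flag and API may differ):

```cpp
// Hypothetical sketch; EXCLUDE_SM_90 and flashMlaAvailable are illustrative,
// not the actual TensorRT-LLM build flag or interface. The idea: FlashMLA
// device code is only compiled for sm90 targets, and other builds get a stub
// that reports the feature as unavailable instead of failing to compile.
#include <cstdio>

bool flashMlaAvailable()
{
#if defined(EXCLUDE_SM_90)
    // FlashMLA kernels were skipped at compile time for this target.
    return false;
#else
    // FlashMLA kernels were compiled in; FP8 generation MLA can dispatch to them.
    return true;
#endif
}

int main()
{
    std::printf("FlashMLA compiled in: %s\n", flashMlaAvailable() ? "yes" : "no");
    return 0;
}
```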
This PR adds FP8 MLA support on Hopper and Blackwell.