
feat: Add support for FP8 MLA on Hopper and Blackwell. #3190


Merged: 14 commits merged into NVIDIA:main on Apr 7, 2025

Conversation

@bobboli bobboli (Collaborator) commented Apr 1, 2025

This PR adds FP8 MLA support on Hopper and Blackwell.

  • Recipe: per-tensor FP8 e4m3 quantization for Q and the latent KV cache, with the MLA output in BF16. No calibration is performed yet; the quantization scales are simply set to 1 (see the sketch after this list).
  • Add code for Q and KV-cache quantization.
  • trtllm-gen based FP8 MLA kernel support for Blackwell.
  • FP8 FlashMLA kernel support for Hopper; this is now the default option on Hopper.
  • (Draft) FMHA-based FP8 MLA kernel for Hopper, with accuracy issues still to be investigated. This PR also includes some refactoring of the FMHA kernel management code.
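
A minimal sketch of the quantization recipe above, assuming PyTorch; the helper name and shapes are illustrative and not part of the actual TensorRT-LLM code:

```python
import torch

def quantize_per_tensor_fp8_e4m3(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """Per-tensor FP8 e4m3 quantization; the scale is fixed to 1 (no calibration yet)."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Divide by the per-tensor scale, clamp to the e4m3 representable range, then cast.
    return (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

# Q and the latent KV cache are quantized to FP8; the MLA output stays in BF16.
q_fp8 = quantize_per_tensor_fp8_e4m3(torch.randn(8, 128, dtype=torch.bfloat16))
kv_fp8 = quantize_per_tensor_fp8_e4m3(torch.randn(8, 576, dtype=torch.bfloat16))
```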

DylanChen-NV and others added 5 commits April 1, 2025 06:55
Use BF16 for context MLA.
mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together.

Allow mSM==90 for mFP8GenerationMLA==true.
For FMHA, dataTypeKv should be FP8.

For FP8 MLA generation, the output is still in BF16.

Refine debug info for FMHA kernel metadata.

Use inputType, outputType, and SM together to hash the kernel list (a sketch of this keying appears after these commit messages).

Add FP8 MLA generation FMHA kernel.

Special WAR of NUM_COMPUTE_GROUPS for MLA generation kernel.

Separate the implementation of fused_multihead_attention_v2.h into a .cpp file and print some debug info if checkIfKernelExist fails.

Refine debug info in fused_multihead_attention_v2.cpp

Correct FP8 MLA metadata.

New kernel provided by Yuxin, which outputs BF16.

The smem size was not set correctly, which led to illegal memory access.

Yuxin fixed the error in the FMHA MLA kernel: previously the BF16 output wasn't written correctly; some parts were written repeatedly while others were left untouched.

There are two bmm1 scales that should be set correctly.

New kernel generated by Yuxin.

Modifications to common/attentionOp for FP8 MLA on Hopper using FMHA.

Not necessary: if mFP8GenerationMLA is set, is_fp8_out is false, so mFP8ContextFMHA is false.

Skip a check in fmhaDispatcher.

Modifications in fmhaRunner:
- Debug dump.
- Guard a lot of flag setting with if (!isFP8GenerationMLA), so it is skipped for FP8 generation MLA.
- TMA descriptor modification for qo (by Yuxin).

Cleanup debug output.

Clean up o tma descriptor modifications.

Signed-off-by: Bo Li <[email protected]>
Signed-off-by: Bo Li <[email protected]>
Signed-off-by: Bo Li <[email protected]>
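
The "hash kernel list" refactor above can be pictured with a short sketch; the names and structure here are illustrative assumptions, not the actual fused_multihead_attention_v2 code:

```python
from typing import NamedTuple

class KernelKey(NamedTuple):
    input_type: str    # e.g. "e4m3"
    output_type: str   # e.g. "bf16"; FP8 MLA generation still outputs BF16
    sm: int            # e.g. 90 (Hopper) or 100 (Blackwell)

# Keying the lookup by the full (input, output, SM) tuple keeps an FP8-in/BF16-out
# MLA kernel distinct from an FP8-in/FP8-out kernel on the same architecture.
kernel_list = {
    KernelKey("e4m3", "bf16", 90): "hypothetical_fmha_fp8_mla_gen_sm90_kernel",
    KernelKey("e4m3", "bf16", 100): "hypothetical_trtllm_gen_fp8_mla_sm100_kernel",
}

def check_if_kernel_exists(key: KernelKey) -> bool:
    if key not in kernel_list:
        # Print some debug info when the lookup fails, mirroring the commit above.
        print(f"No kernel for input={key.input_type}, output={key.output_type}, sm={key.sm}")
        return False
    return True
```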
@bobboli bobboli (Collaborator, Author) commented Apr 1, 2025

/bot run

@bobboli bobboli requested review from Tracin and DylanChen-NV April 1, 2025 07:01
@tensorrt-cicd
PR_Github #869 [ run ] triggered by Bot

@tensorrt-cicd
PR_Github #869 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #693 completed with status: 'FAILURE'

@@ -531,7 +531,9 @@ def forward_context(
 # Concat q (including q_pe), k + k_pe, v together as input_qkv
 input_qkv = torch.cat([q, k, v], dim=-1)

-out_scale = getattr(self.o_proj, "inv_input_scale", None)
+# out_scale = getattr(self.o_proj, "inv_input_scale", None)
+out_scale = None  # Currently we use BF16 MHA for context phase
@bobboli bobboli (Collaborator, Author) Apr 1, 2025

If out_scale is not None, attentionOp will assume that the attention output type is FP8. Currently we want to keep context MLA, as well as the output of generation MLA, in BF16.
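
To illustrate the convention described in this comment, here is a minimal sketch (a hypothetical helper, not the actual attentionOp interface) of how a non-None out_scale selects the FP8 output path while None keeps BF16:

```python
import torch

def select_attention_output_dtype(out_scale: torch.Tensor | None) -> torch.dtype:
    # Convention from the comment above: a non-None out_scale tells the attention
    # op to quantize its output to FP8; None keeps the output in BF16.
    return torch.float8_e4m3fn if out_scale is not None else torch.bfloat16

# Context MLA (and the output of generation MLA) stay in BF16 for now:
assert select_attention_output_dtype(None) == torch.bfloat16
```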

@bobboli bobboli (Collaborator, Author) commented Apr 1, 2025

/bot run

@tensorrt-cicd
PR_Github #889 [ run ] triggered by Bot

@tensorrt-cicd
PR_Github #889 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #703 completed with status: 'FAILURE'

Signed-off-by: Bo Li <[email protected]>
@bobboli bobboli (Collaborator, Author) commented Apr 1, 2025

/bot run

@tensorrt-cicd
PR_Github #912 [ run ] triggered by Bot

@tensorrt-cicd
PR_Github #912 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #720 completed with status: 'FAILURE'

@bobboli bobboli (Collaborator, Author) commented Apr 2, 2025

/bot run

@tensorrt-cicd
PR_Github #956 [ run ] triggered by Bot

@tensorrt-cicd
PR_Github #956 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #747 completed with status: 'FAILURE'

@bobboli bobboli (Collaborator, Author) commented Apr 2, 2025

/bot run

@tensorrt-cicd
PR_Github #957 [ run ] triggered by Bot

@tensorrt-cicd
PR_Github #957 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #748 completed with status: 'FAILURE'

@bobboli bobboli (Collaborator, Author) commented Apr 3, 2025

/bot run

@tensorrt-cicd
PR_Github #1133 [ run ] triggered by Bot

@tensorrt-cicd
PR_Github #1133 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #853 completed with status: 'FAILURE'

@bobboli bobboli (Collaborator, Author) commented Apr 6, 2025

/bot run

@tensorrt-cicd
PR_Github #1230 [ run ] triggered by Bot

@tensorrt-cicd
PR_Github #1230 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #923 completed with status: 'FAILURE'

@bobboli bobboli (Collaborator, Author) commented Apr 7, 2025

/bot run

@tensorrt-cicd
PR_Github #1248 [ run ] triggered by Bot

@tensorrt-cicd
PR_Github #1248 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #940 completed with status: 'SUCCESS'

@QiJune QiJune (Collaborator) commented Apr 7, 2025

/bot reuse-pipeline

@QiJune QiJune enabled auto-merge (squash) April 7, 2025 07:00
@tensorrt-cicd
PR_Github #1281 [ reuse-pipeline ] triggered by Bot

@tensorrt-cicd
PR_Github #1281 [ reuse-pipeline ] completed with state SUCCESS
Reusing PR_Github #1248 for commit efda97d

@QiJune QiJune merged commit 515dd0d into NVIDIA:main Apr 7, 2025
2 checks passed
sarattha pushed a commit to sarattha/TensorRT-LLM that referenced this pull request Apr 9, 2025
* fp8 kv + bf16 ctx MLA + fp8 gen MLA

(Squashed commit message; identical to the commit descriptions listed above.)

Signed-off-by: Bo Li <[email protected]>

* Resolve conflicts.

Signed-off-by: Bo Li <[email protected]>

* Apply the patch of FP8 FlashMLA and resolve conflicts.

Signed-off-by: Bo Li <[email protected]>

* Fix compilation error.

Signed-off-by: Bo Li <[email protected]>

* Fix compile error.

Signed-off-by: Bo Li <[email protected]>

* pick blackwell support

Signed-off-by: Dylan Chen <[email protected]>

* Add copyright notice to fused_multihead_attention_v2.cpp.

Signed-off-by: Bo Li <[email protected]>

* Add license.

Signed-off-by: Bo Li <[email protected]>

* Add missing license.

Signed-off-by: Bo Li <[email protected]>

* Exclude building flashMLA kernels under sm90.

Signed-off-by: Bo Li <[email protected]>

* Revert "Exclude building flashMLA kernels under sm90."

    This reverts commit f0c859d.

Signed-off-by: Bo Li <[email protected]>

* Use macro to skip compiling FlashMLA for non sm90 targets.

Signed-off-by: Bo Li <[email protected]>

---------

Signed-off-by: Bo Li <[email protected]>
Signed-off-by: Dylan Chen <[email protected]>
Co-authored-by: Dylan Chen <[email protected]>
Co-authored-by: Dylan Chen <[email protected]>
Co-authored-by: QI JUN <[email protected]>
Signed-off-by: sarattha <[email protected]>
tomeras91 pushed a commit to tomeras91/TensorRT-LLM that referenced this pull request Apr 9, 2025
(Same squashed commit message and commit list as above.)
tomeras91 pushed a commit to tomeras91/TensorRT-LLM that referenced this pull request Apr 9, 2025
(Same squashed commit message and commit list as above.)
Labels: None yet
Projects: None yet
5 participants