
feat: Add support for FP8 MLA on Hopper and Blackwell. #3190

Merged
merged 14 commits into from Apr 7, 2025
117 changes: 99 additions & 18 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -22,6 +22,7 @@
#include "tensorrt_llm/kernels/flashMLA/flash_mla.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
@@ -778,21 +779,24 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32

size_t cu_seqlens_size = sizeof(int) * (max_num_seq + 1);
size_t fmha_scheduler_counter = sizeof(uint32_t);
size_t headDim = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;

int const NUM_BUFFERS = 7;
int const NUM_BUFFERS = 10;
size_t workspaces[NUM_BUFFERS];
workspaces[0] = cu_seqlens_size; // cu_q_len
workspaces[1] = cu_seqlens_size; // cu_kv_len
workspaces[0] = cu_seqlens_size; // cu_q_len
workspaces[1] = cu_seqlens_size; // cu_kv_len
workspaces[2] = fmha_scheduler_counter;
workspaces[3] = mFP8GenerationMLA ? sizeof(float) * 2 : 0; // mla_bmm1_scale_size
workspaces[4] = mFP8GenerationMLA ? sizeof(float) : 0; // mla_bmm2_scale_size
workspaces[5] = mFP8GenerationMLA ? max_num_tokens * size_t(mNumHeads * headDim) : 0; // quant q buffer
// The multiCtasKvMode buffers. Each CTA at most handles 256 rows.
// And the seqLenKv is split into at most mMultiProcessorCount tiles.
workspaces[3] = size * 256 * mMultiProcessorCount * mMLAParams.kv_lora_rank;
workspaces[6] = size * 256 * mMultiProcessorCount * headDim;
// The partialSum size.
workspaces[4] = sizeof(float) * 256 * mMultiProcessorCount;
workspaces[7] = sizeof(float) * 256 * mMultiProcessorCount;
// The partialMax size.
workspaces[5] = sizeof(float) * 256 * mMultiProcessorCount;

workspaces[6] = flash_mla_workspace_size;
workspaces[8] = sizeof(float) * 256 * mMultiProcessorCount;
workspaces[9] = flash_mla_workspace_size;

return tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
}
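
The generation workspace grows from 7 to 10 buffers: the three FP8 additions are a two-float BMM1 scale, a one-float BMM2 scale, and a staging buffer for the quantized Q activations, which is sized in raw bytes because FP8 occupies one byte per element. For a sense of scale, a minimal sketch of the extra sizes, assuming DeepSeek-V3-like dimensions (kv_lora_rank = 512, qk_rope_head_dim = 64) plus hypothetical values of 128 heads and max_num_tokens = 256:

// Illustrative only: rough size of the new FP8 workspace buffers.
#include <cstddef>
#include <cstdio>

int main()
{
    std::size_t const headDim = 512 + 64;                 // kv_lora_rank + qk_rope_head_dim
    std::size_t const bmm1ScaleBytes = sizeof(float) * 2; // workspaces[3]
    std::size_t const bmm2ScaleBytes = sizeof(float);     // workspaces[4]
    std::size_t const quantQBytes = 256 * 128 * headDim;  // workspaces[5]: max_num_tokens * num_heads * headDim, 1 byte per FP8 value
    std::printf("bmm1 = %zu B, bmm2 = %zu B, quantized Q = %.1f MiB\n", bmm1ScaleBytes, bmm2ScaleBytes,
        quantQBytes / (1024.0 * 1024.0));
    return 0;
}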
@@ -864,9 +868,11 @@ int AttentionOp::mlaGeneration(
int const num_kv_heads = 1;
int const head_size = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
int32_t const batch_beam = generation_params.beam_width * generation_params.num_requests;

// The element size of the KV cache.
int elemSize = sizeof(T);
auto const elemSize = mKVCacheQuantMode.hasFp8KvCache() ? sizeof(__nv_fp8_e4m3) : sizeof(T);
auto const sizePerToken = num_kv_heads * head_size * elemSize;
params.cache_type = (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);

auto kv_cache_buffer = KVBlockArray(batch_beam, generation_params.max_blocks_per_sequence, mTokensPerBlock,
sizePerToken, generation_params.cyclic_attention_window_size,
@@ -880,15 +886,36 @@

size_t const cu_seqlens_size = sizeof(int) * (params.batch_size + 1);
size_t const fmha_scheduler_counter = sizeof(uint32_t);
size_t const mla_bmm1_scale_size = mFP8GenerationMLA ? sizeof(float) * 2 : 0;
size_t const mla_bmm2_scale_size = mFP8GenerationMLA ? sizeof(float) : 0;
size_t const quant_q_buffer_size = mFP8GenerationMLA
? params.acc_q_len * size_t(mNumHeads * (mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim))
: 0;
int* cu_q_seqlens = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, cu_seqlens_size));
int* cu_kv_seqlens = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, cu_seqlens_size));
uint32_t* fmha_tile_counter_ptr
= reinterpret_cast<uint32_t*>(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_scheduler_counter));
float* mla_bmm1_scale_ptr
= reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, mla_bmm1_scale_size));
float* mla_bmm2_scale_ptr
= reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, mla_bmm2_scale_size));
void* quant_q_buffer_ptr
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, quant_q_buffer_size));
void* scratch_ptr = nextWorkspacePtr(workspace_byte_ptr, offset);

params.seqQOffset = cu_q_seqlens;
params.cu_kv_seqlens = cu_kv_seqlens;
params.fmha_tile_counter = fmha_tile_counter_ptr;
params.bmm1_scale = mla_bmm1_scale_ptr;
params.bmm2_scale = mla_bmm2_scale_ptr;
params.quant_attention_input_buf = quant_q_buffer_ptr;

params.quant_scale_o = generation_params.attention_output_orig_quant;
params.quant_scale_q = generation_params.kv_scale_orig_quant;
params.quant_scale_kv = generation_params.kv_scale_orig_quant;
params.dequant_scale_q = generation_params.kv_scale_quant_orig;
params.dequant_scale_kv = generation_params.kv_scale_quant_orig;
params.host_bmm1_scale = 1 / (sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));

invokeMLARopeGeneration<T>(params, kv_cache_buffer, stream);
sync_check_cuda_error(stream);
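
On the FP8 path, the same workspace is carved into the two BMM scale buffers and the quantized-Q buffer, and the Q quantization reuses the KV-cache scales (quant_scale_q and quant_scale_kv both point at kv_scale_orig_quant). The host-side softmax scale depends only on the nope+rope head width; a minimal worked example, assuming DeepSeek-V3-like head dimensions:

// Worked example of host_bmm1_scale, assuming qk_nope_head_dim = 128 and
// qk_rope_head_dim = 64 (DeepSeek-V3-like values; the real ones come from mMLAParams).
#include <cmath>
#include <cstdio>

int main()
{
    float const qk_nope_head_dim = 128.f;
    float const qk_rope_head_dim = 64.f;
    float const host_bmm1_scale = 1.f / std::sqrt(qk_nope_head_dim + qk_rope_head_dim);
    std::printf("host_bmm1_scale = %f\n", host_bmm1_scale); // ~0.072169 for the 192-wide QK dot product
    return 0;
}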
@@ -920,7 +947,8 @@ int AttentionOp::mlaGeneration(
tllmRunnerParams.mTileScheduler = mMultiBlockMode ? TileScheduler::Static : TileScheduler::Persistent;

// Q buffer.
tllmRunnerParams.qPtr = reinterpret_cast<void const*>(params.attention_input_buf);
tllmRunnerParams.qPtr = mFP8GenerationMLA ? reinterpret_cast<void const*>(params.quant_attention_input_buf)
: reinterpret_cast<void const*>(params.attention_input_buf);

// KV buffer
// Paged KV
@@ -972,8 +1000,14 @@
tllmRunnerParams.stream = stream;
tllmRunnerParams.mSfStartTokenIdx = generation_params.start_token_idx_sf;

// Scales for quantization
static constexpr int bmm1_scale_offset = 1;
tllmRunnerParams.outputScalePtr = reinterpret_cast<float const*>(params.bmm2_scale);
tllmRunnerParams.scaleSoftmaxLog2Ptr = reinterpret_cast<float const*>(params.bmm1_scale) + bmm1_scale_offset;
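// Assumed layout of the device-side scale buffers (presumably filled by
// invokeMLARopeGeneration above): mla_bmm1_scale holds two floats, of which
// element [1] is read here as the log2-domain softmax scale, while element [0]
// presumably carries the same scale in the linear domain; mla_bmm2_scale holds
// the single output (BMM2) dequantization scale consumed via outputScalePtr.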

TLLM_CHECK_WITH_INFO(mTllmGenFMHARunner.get(), "mTllmGenFMHARunner not initialized.");
mTllmGenFMHARunner->run(tllmRunnerParams);
sync_check_cuda_error(stream);
}
else if (mUseFlashMLA)
{
@@ -1034,7 +1068,9 @@ int AttentionOp::mlaGeneration(
flashMlaParams.scale_softmax = softmax_scale;
flashMlaParams.scale_softmax_log2 = float(softmax_scale * M_LOG2E);

flashMlaParams.q_ptr = const_cast<void*>(reinterpret_cast<void const*>(params.attention_input_buf));
flashMlaParams.q_ptr = mFP8GenerationMLA
? const_cast<void*>(reinterpret_cast<void const*>(params.quant_attention_input_buf))
: const_cast<void*>(reinterpret_cast<void const*>(params.attention_input_buf));
flashMlaParams.k_ptr = kv_cache_buffer.mPrimaryPoolPtr;
flashMlaParams.v_ptr = flashMlaParams.k_ptr;
flashMlaParams.o_ptr = reinterpret_cast<void*>(params.context_buf);
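
The scale_softmax_log2 value above is the usual exp-to-exp2 rewrite: since exp(s * x) == exp2(s * log2(e) * x), kernels that prefer the hardware exp2 path take the softmax scale premultiplied by M_LOG2E. A minimal check of the identity (the 1/sqrt(192) value is only an example; FlashMLA receives its own scale):

// Verifies exp(s * x) == exp2(s * M_LOG2E * x) numerically.
#include <cmath>
#include <cstdio>

int main()
{
    double const softmax_scale = 1.0 / std::sqrt(192.0); // example scale only
    double const x = 3.7;                                // arbitrary logit
    double const a = std::exp(softmax_scale * x);
    double const b = std::exp2(softmax_scale * M_LOG2E * x);
    std::printf("%.12f vs %.12f\n", a, b); // identical up to rounding
    return 0;
}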
@@ -1059,6 +1095,9 @@
flashMlaParams.block_table_batch_stride = generation_params.max_blocks_per_sequence;
flashMlaParams.page_block_size = mTokensPerBlock;

flashMlaParams.descale_q_ptr = const_cast<float*>(params.dequant_scale_q);
flashMlaParams.descale_k_ptr = const_cast<float*>(params.dequant_scale_kv);

flashMlaParams.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
flashMlaParams.num_sm_parts = num_sm_parts;
flashMlaParams.num_splits_ptr = num_splits_ptr;
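
The new descale_q_ptr / descale_k_ptr fields carry per-tensor dequantization factors for the FP8 Q input and the FP8 KV cache (both wired to kv_scale_quant_orig, since Q was quantized with the KV scale). As a reminder of the convention, a small host-side sketch of per-tensor FP8 quantize/dequantize, assuming a CUDA Toolkit that ships cuda_fp8.h (11.8 or newer):

// A value quantized as q = fp8(x * scale) is recovered as x ~= float(q) * descale,
// with descale = 1 / scale; the loss is the FP8 rounding error.
#include <cstdio>
#include <cuda_fp8.h>

int main()
{
    float const x = 1.3f;
    float const scale = 0.25f;         // quantization scale (original -> quantized)
    float const descale = 1.f / scale; // dequantization scale (quantized -> original)
    __nv_fp8_e4m3 const q(x * scale);
    float const x_hat = float(q) * descale;
    std::printf("x = %f, x_hat = %f\n", x, x_hat);
    return 0;
}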
@@ -1068,12 +1107,25 @@

if constexpr (std::is_same<T, half>::value)
{
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(flashMlaParams, stream);
if (mFP8GenerationMLA)
{
TLLM_THROW("FP8 KV cache MLA is only supported for bf16 output");
}
else
{
run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(flashMlaParams, stream);
}
}
else if constexpr (std::is_same<T, __nv_bfloat16>::value)
{

run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(flashMlaParams, stream);
if (mFP8GenerationMLA)
{
run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(flashMlaParams, stream);
}
else
{
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(flashMlaParams, stream);
}
}
else
{
@@ -1093,7 +1145,8 @@ int AttentionOp::mlaGeneration(
// fmhaParams.totalKvSeqLen = params.num_tokens;
// Device buffer pointers.
// fmhaParams.qkvPtr = reinterpret_cast<void const*>(params.attention_input);
fmhaParams.qPtr = reinterpret_cast<void const*>(params.attention_input_buf);
fmhaParams.qPtr = mFP8GenerationMLA ? reinterpret_cast<void const*>(params.quant_attention_input_buf)
: reinterpret_cast<void const*>(params.attention_input_buf);
// TODO: add contiguous kv buffer (cross-attention).
fmhaParams.kvPtr = nullptr;

@@ -1106,8 +1159,8 @@
fmhaParams.cuKvSeqLenPtr = cu_kv_seqlens;
fmhaParams.cuMaskRowsPtr = nullptr; // MLA does not support custom masks right now
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
fmhaParams.scaleBmm1Ptr = nullptr;
fmhaParams.scaleBmm2Ptr = nullptr;
fmhaParams.scaleBmm1Ptr = reinterpret_cast<float const*>(params.bmm1_scale);
fmhaParams.scaleBmm2Ptr = reinterpret_cast<float const*>(params.bmm2_scale);
fmhaParams.stream = stream;
fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;

@@ -1462,7 +1515,9 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
}
if (mIsMLAEnabled)
{
params.mla_param->cache_type = cache_type;
params.mla_param->cu_q_seqlens = cu_q_seqlens;
params.mla_param->quant_scale_kv = params.kv_scale_orig_quant;
invokeMLARopeContext<T, KVCacheBuffer>(*params.mla_param, kv_cache_buffer, stream);
}
else
@@ -2225,6 +2280,14 @@ int AttentionOp::initialize() noexcept
mSM == 89 || mSM == 90 || mSM == 100, "FP8 FMHA can only be enabled on sm_89, sm_90 or sm_100.");
}

// Pre-Check of FP8 Generation MLA.
if (mFP8GenerationMLA)
{
TLLM_CHECK_WITH_INFO(mIsMLAEnabled, "FP8 Generation MLA cannot be enabled because MLA is not supported.");
TLLM_CHECK_WITH_INFO(
mSM == 90 || mSM == 100, "FP8 Generation MLA is supported on Hopper or Blackwell architecture.");
}

// Check requirements for FP4 output.
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mEnableContextFMHA, "Context FMHA must enable if fuse_fp4_quant is enabled");
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || (mSM >= 100), "fuse_fp4_quant only supports SM100 and later devices.");
Expand Down Expand Up @@ -2314,6 +2377,12 @@ int AttentionOp::initialize() noexcept
// If FP4 quantization workflow is enabled, set output type to FP4.
fmhaParams.dataTypeOut = DATA_TYPE_E2M1;
}
if (mIsMLAEnabled)
{
// For FP8 MLA, currently context attention is performed in BF16.
fmhaParams.dataTypeOut = DATA_TYPE_BF16;
fmhaParams.dataTypeKv = DATA_TYPE_BF16;
}
// TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to
// bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in runtime.
fmhaParams.forceFp32Acc = false;
@@ -2406,13 +2475,25 @@ int AttentionOp::initialize() noexcept
else
{
// Construct the fmha runner.
// FP8 Generation MLA also uses context FMHA.
if (mFP8GenerationMLA)
{
data_type = DATA_TYPE_E4M3;
}
MHARunnerFixedParams fmhaParams{};
fmhaParams.dataType = data_type;
fmhaParams.dataTypeKv = data_type;
fmhaParams.dataTypeOut = data_type;
// For FP8 MLA generation, the output type is BF16, and the quantization before o_proj is performed
// separately.
if (mFP8GenerationMLA)
{
fmhaParams.dataTypeOut = DATA_TYPE_BF16;
}
// TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to
// bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in
// runtime.
fmhaParams.forceFp32Acc = false;
fmhaParams.forceFp32Acc = true;
fmhaParams.attentionMaskType
= useCustomMask() ? ContextAttentionMaskType::CUSTOM_MASK : ContextAttentionMaskType::PADDING;
// TODO: set it to Q_CONTIGUOUS_KV layout for cross-attention.
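
Taken together, the FP8 generation-MLA runner consumes E4M3 Q and KV, writes BF16 output (the quantization before o_proj happens separately), and forces FP32 accumulation, while the context path above stays in BF16. A self-contained mirror of that field combination, for illustration only (the real configuration lives in the TensorRT-LLM MHARunnerFixedParams set up above):

// Hypothetical stand-in types; only the value combination matters here.
#include <cstdio>

enum class DataType { E4M3, BF16 };

struct Fp8MlaRunnerConfigSketch
{
    DataType dataType = DataType::E4M3;    // quantized Q input
    DataType dataTypeKv = DataType::E4M3;  // FP8 KV cache
    DataType dataTypeOut = DataType::BF16; // output stays BF16
    bool forceFp32Acc = true;              // FP32 accumulation
};

int main()
{
    Fp8MlaRunnerConfigSketch const cfg{};
    std::printf("out is BF16: %d, fp32 acc: %d\n", cfg.dataTypeOut == DataType::BF16, cfg.forceFp32Acc);
    return 0;
}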
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/common/attentionOp.h
@@ -370,6 +370,7 @@ class AttentionOp
bool mPosShiftEnabled = false;
bool mPagedContextFMHA = false;
bool mFP8ContextFMHA = false;
bool mFP8GenerationMLA = false;
bool mDenseContextFMHA = false;
bool mHasFullAttentionMask = false;
bool mIsSpecDecodingEnabled = false;
@@ -16,6 +16,7 @@
*/
#pragma once

#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
namespace tensorrt_llm
{
namespace kernels
@@ -381,6 +382,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sof
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_64_S_q_paged_kv_576x256_tma_ws_sm90_cu_cubin[];
#endif

#ifndef EXCLUDE_SM_89
@@ -1633,6 +1635,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapp
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_64_S_q_paged_kv_576x256_tma_ws_sm90_cu_cubin_len;
#endif

#ifndef EXCLUDE_SM_89
@@ -3515,6 +3518,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_sliding_window_causal_softcapping_sm90_kernel_nl", 49152, 128, 64, 2, 0, false, true, false, true, true, false, true, false},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, false},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, false},
// TODO: FMHA FP8 MLA kernel needs to be regenerated.
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_q_paged_kv_576x256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_q_paged_kv_576x256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_q_paged_kv_576x256_tma_ws_sm90_kernel", 193792, 384, 64, 0, 2, false, true, true, true, false, false, false, false},
#endif

#ifndef EXCLUDE_SM_89