Commit 31624b0

hlu1 and zongfeijing authored
feat: [Deepseek] Add trtllm-gen MOE FP4 MOE backend (#3387)
* Add TRT-LLM Gen MOE to Deepseek
  - fix fused moe rebase bug
  - Fix atol in test_fp4_gemm_quantize.py
  - fix fused moe rebase bug
  - Fix FusedMoe
  - Disable 2nd routing kernel preexit
  - Bump routing reduction to fp32
  - Disable PDL for fc1
  - [DEBUG] Lift token limit to 16k
  - [Bugfix] Token limit to 16k + fp32 routing + tanh
  - Make fp8 tileN 8
  - Fix FP8 MoE + Remove redundant temp output for FP4
  - [FP8-only] Avoid wasting CTAs for activation kernel
  - fix: unblock FP8 weight loading with trtllm-gen
  - Remove max_token limit for trtllm-gen path
  - perf: avoid type-conversion and fill_ from aten
  - Minor fix

  Signed-off-by: Hao Lu <[email protected]>

* Fix rebase issues

  Signed-off-by: Hao Lu <[email protected]>

* Fix compile issue

  Signed-off-by: Zongfei Jing <[email protected]>

* CI clean

  Signed-off-by: Zongfei Jing <[email protected]>

---------

Signed-off-by: Hao Lu <[email protected]>
Signed-off-by: Zongfei Jing <[email protected]>
Co-authored-by: Zongfei Jing <[email protected]>
1 parent 48db263 · commit 31624b0

60 files changed: +95,448 −22,338 lines


cpp/tensorrt_llm/common/cudaDriverWrapper.cpp (+48 −1)
@@ -30,7 +30,7 @@

 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/cudaDriverWrapper.h"
-
+#include "tensorrt_llm/common/logger.h"
 #include <cuda.h>

 #include <cstdio>
@@ -175,9 +175,56 @@ CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX,
         f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
 }

+namespace
+{
+std::string stringify_launch_config(CUlaunchConfig const& config)
+{
+    std::stringstream ss;
+
+    // Grid dimensions (Driver API uses separate fields)
+    ss << "Grid Dimensions: (" << config.gridDimX << ", " << config.gridDimY << ", " << config.gridDimZ << ")\n";
+
+    // Block dimensions
+    ss << "Block Dimensions: (" << config.blockDimX << ", " << config.blockDimY << ", " << config.blockDimZ << ")\n";
+
+    // Shared memory and stream (Driver API uses hStream)
+    ss << "Shared Memory: " << config.sharedMemBytes << " bytes\n";
+    ss << "Stream: " << (config.hStream ? "Custom" : "Default") << " (0x" << std::hex
+       << reinterpret_cast<uintptr_t>(config.hStream) << ")\n";
+
+    // Attributes (Driver API uses value instead of val)
+    ss << "Attributes (" << config.numAttrs << "):\n";
+    for (uint i = 0; i < config.numAttrs; ++i)
+    {
+        CUlaunchAttribute const& attr = config.attrs[i];
+        ss << "  [" << i << "] ";
+
+        switch (attr.id)
+        {
+        case CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION:
+            ss << "Cluster Dimension: (" << attr.value.clusterDim.x << ", " << attr.value.clusterDim.y << ", "
+               << attr.value.clusterDim.z << ")";
+            break;
+
+        case CU_LAUNCH_ATTRIBUTE_PRIORITY: ss << "Priority: " << attr.value.priority; break;
+
+        // Handle other Driver API attributes here
+        default: ss << "Unknown Attribute (ID=" << attr.id << ")"; break;
+        }
+        ss << "\n";
+    }
+
+    return ss.str();
+}
+} // namespace
+
 CUresult CUDADriverWrapper::cuLaunchKernelEx(
     CUlaunchConfig const* config, CUfunction f, void** kernelParams, void** extra) const
 {
+
+    TLLM_LOG_DEBUG("Launch config: %s", stringify_launch_config(*config).c_str());
+    TLLM_CHECK_DEBUG_WITH_INFO(
+        (extra != nullptr) != (kernelParams != nullptr), "Exactly one of 'extra' and 'kernelParams' should be set.");
     return (*_cuLaunchKernelEx)(config, f, kernelParams, extra);
 }
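Review note: the new debug line walks every field of the CUlaunchConfig handed to cuLaunchKernelEx. Below is a minimal, hedged sketch (not part of the commit) of how a caller might assemble such a config so the stringifier has attributes to print; the helper name make_debug_config and all dimension values are illustrative, while the types and attribute IDs are the standard CUDA 11.8+ driver ones.

// Hedged sketch (not in the commit): assembling a CUlaunchConfig so the new
// stringify_launch_config debug output has attributes to print. Requires the
// CUDA 11.8+ driver header; make_debug_config and all numbers are illustrative.
#include <cuda.h>

CUlaunchConfig make_debug_config(CUstream stream)
{
    static CUlaunchAttribute attrs[2]; // must outlive the config that points at it
    attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
    attrs[0].value.clusterDim.x = 2; // 2 CTAs per cluster
    attrs[0].value.clusterDim.y = 1;
    attrs[0].value.clusterDim.z = 1;
    attrs[1].id = CU_LAUNCH_ATTRIBUTE_PRIORITY;
    attrs[1].value.priority = 0;

    CUlaunchConfig config{};
    config.gridDimX = 128;
    config.gridDimY = config.gridDimZ = 1;
    config.blockDimX = 256;
    config.blockDimY = config.blockDimZ = 1;
    config.sharedMemBytes = 0;
    config.hStream = stream; // stringified as "Custom" when non-null
    config.attrs = attrs;
    config.numAttrs = 2;
    return config;
}

The XOR-style TLLM_CHECK_DEBUG_WITH_INFO then enforces that exactly one of kernelParams and extra is supplied, which matches the driver API's contract for parameter passing.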

cpp/tensorrt_llm/kernels/quantization.cu (+40 −1)
@@ -259,7 +259,7 @@ void invokeBatchedFP4Quantization(int b, int m, int n, __nv_fp8_e4m3 const* inpu
     dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512));
     // Get number of blocks per SM (assume we can fully utilize the SM).
     int const numBlocksPerSM = 2048 / block.x;
-    dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+    dim3 grid(std::min(m, multiProcessorCount * numBlocksPerSM));

     // Launch the cvt kernel.
     if (useUE8M0)
@@ -277,6 +277,45 @@
 }
 #endif

+__global__ void nvfp4_block_scale_interleave_kernel(
+    int numbatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput)
+{
+    for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x)
+    {
+        for (int batchIdx = 0; batchIdx < numbatches; batchIdx++)
+        {
+            for (int colIdx = threadIdx.x; colIdx < numCols; colIdx += blockDim.x)
+            {
+                int64_t inOffset = batchIdx * numRows * numCols + rowIdx * numCols + colIdx;
+                auto sf = SFIn[inOffset];
+
+                std::optional<int> batchIdxOpt = batchIdx;
+                std::optional<int> numRowsOpt = numRows;
+
+                // Without batching, the math in get_sf_out_offset is the same as
+                // int const numSfTilesK = (numCols + 4 - 1) / 4;
+                // int const tileOffset = ((mi / 128) * numSfTilesK + ki / 4) * 512;
+                // int const dstIdx = tileOffset + (mi % 32) * 16 + ((mi % 128) / 32) * 4 + ki % 4;
+                auto dstIdx = get_sf_out_offset_128x4(batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols * 16);
+                SFOutput[dstIdx] = sf;
+            }
+        }
+    }
+}
+
+// This is intended for weight loading, so m and n are large, b <= 256
+void invokeNVFP4BlockScaleInterleave(
+    int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream)
+{
+    // Each thread reads 1 int8 value
+    dim3 block(std::min(n, 1024));
+    // Get number of blocks per SM (assume we can fully utilize the SM).
+    int const numBlocksPerSM = 4096 / block.x;
+    dim3 grid(std::min(m, multiProcessorCount * numBlocksPerSM));
+
+    nvfp4_block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
+}
+
 // Instantiate the function.
 template void invokeFP4Quantization(int m, int n, half const* input, float const* SFScale, int64_t* output,
     int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream);
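Review note: the new launcher sizes its grid the same way the FP8 quantization path above does, only with one thread per scale-factor byte and a larger blocks-per-SM budget (4096 / block.x instead of 2048 / block.x). A hedged, host-only sketch of that arithmetic; the SM count and matrix shape are assumptions for illustration:

// Hedged sketch (not in the commit): the launch-shape arithmetic of
// invokeNVFP4BlockScaleInterleave, run on the host. Query the real SM count
// with cudaDeviceGetAttribute; 132 and the shape below are illustrative.
#include <algorithm>
#include <cstdio>

int main()
{
    int const multiProcessorCount = 132; // assumed SM count
    int const m = 7168, n = 128;         // illustrative: 7168 rows x 128 SF columns

    int const blockX = std::min(n, 1024);     // each thread reads 1 int8 value
    int const numBlocksPerSM = 4096 / blockX; // the kernel's blocks-per-SM budget
    int const gridX = std::min(m, multiProcessorCount * numBlocksPerSM);

    // Prints block=(128,1,1) grid=(4224,1,1): the SM cap wins over m here,
    // and the kernel's row loop (rowIdx += gridDim.x) covers the remainder.
    std::printf("block=(%d,1,1) grid=(%d,1,1)\n", blockX, gridX);
    return 0;
}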

cpp/tensorrt_llm/kernels/quantization.cuh (+44 −36)
@@ -527,8 +527,7 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
     }
     // Get the output scale.
     // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal))
-    float outputScale
-        = SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
+    float outputScale = SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValue) : 0.0f;

     if (SFout)
     {
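Review note: the old and new expressions agree algebraically. reciprocal(SFValue * reciprocal(SFScaleVal)) and SFScaleVal * reciprocal(SFValue) both evaluate SFScaleVal / SFValue; the rewrite just spends one approximate reciprocal instead of two. A tiny host-side check using exact reciprocals (the device code uses reciprocal_approximate_ftz, so the last ULPs may differ there):

// Hedged sketch (not in the commit): the two outputScale recipes compute the
// same quotient; values are illustrative.
#include <cstdio>

int main()
{
    float const SFScaleVal = 0.125f, SFValue = 3.0f;
    float const oldScale = 1.0f / (SFValue * (1.0f / SFScaleVal)); // two reciprocals
    float const newScale = SFScaleVal * (1.0f / SFValue);          // one reciprocal
    std::printf("old=%.9g new=%.9g\n", oldScale, newScale);        // both ~0.0416667
    return 0;
}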
@@ -557,6 +556,46 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
 #endif
 }

+inline __device__ int64_t get_sf_out_offset_128x4(
+    std::optional<int> batchIdx, int mIdx, int kIdx, std::optional<int> numRows, int numCols)
+{
+    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
+
+    // batched tensor
+    // SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+    // --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
+
+    int32_t innerKIdx = (kIdx % 4);
+    int64_t innerKStride = 1;
+
+    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
+    int64_t innerMStride = 4 * innerKStride; // 4
+
+    // M tile layout [32, 4] is column-major.
+    int32_t outerMIdx = (mIdx % 32);
+    int64_t outerMStride = 4 * innerMStride; // 16
+
+    int32_t kTileIdx = (kIdx / 4);
+    int64_t kTileStride = 32 * outerMStride; // 512
+
+    // SF vector size 16. We round the "numCols" up to a multiple of 64.
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    int32_t numKTiles = (numCols + factor - 1) / factor;
+    int32_t mTileIdx = mIdx / (32 * 4);
+    int64_t mTileStride = numKTiles * kTileStride;
+
+    // Each SF block has 128 rows so pad rows to the multiple of 128.
+    int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128;
+    int64_t bTileStride = numMTiles * mTileStride;
+
+    // Compute the global offset.
+    int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + kTileIdx * kTileStride
+        + outerMIdx * outerMStride + innerMIdx * innerMStride + innerKIdx * innerKStride;
+
+    return SFOffset;
+}
+
 template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
 __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx, int colIdx,
     std::optional<int> numRows, int numCols, SFType* SFout, FP4QuantizationSFLayout layout)
@@ -576,40 +615,7 @@ __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchI
         int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
         int32_t mIdx = rowIdx;

-        // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
-        // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
-
-        // batched tensor
-        // SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
-        // --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
-
-        int32_t innerKIdx = (kIdx % 4);
-        int64_t innerKStride = 1;
-
-        int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
-        int64_t innerMStride = 4 * innerKStride; // 4
-
-        // M tile layout [32, 4] is column-major.
-        int32_t outerMIdx = (mIdx % 32);
-        int64_t outerMStride = 4 * innerMStride; // 16
-
-        int32_t kTileIdx = (kIdx / 4);
-        int64_t kTileStride = 32 * outerMStride; // 512
-
-        // SF vector size 16. We round the "numCols" up to a multiple of 64.
-        int factor = CVT_FP4_SF_VEC_SIZE * 4;
-        int32_t numKTiles = (numCols + factor - 1) / factor;
-        int32_t mTileIdx = mIdx / (32 * 4);
-        int64_t mTileStride = numKTiles * kTileStride;
-
-        // Each SF block has 128 rows so pad rows to the multiple of 128.
-        int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128;
-        int64_t bTileStride = numMTiles * mTileStride;
-
-        // Compute the global offset.
-        int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + kTileIdx * kTileStride
-            + outerMIdx * outerMStride + innerMIdx * innerMStride + innerKIdx * innerKStride;
-
+        auto SFOffset = get_sf_out_offset_128x4(batchIdx, mIdx, kIdx, numRows, numCols);
         return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
     }
     else if (layout == FP4QuantizationSFLayout::LINEAR)
@@ -819,5 +825,7 @@ cvt_fp8_to_fp4(
 #endif
 }

+__global__ void nvfp4_block_scale_interleave_kernel(
+    int numbatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput);
 } // namespace kernels
 } // namespace tensorrt_llm
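Review note: with batchIdx = 0 the bTileStride term vanishes, so the factored-out helper should reduce to the simplified formula quoted in the nvfp4_block_scale_interleave_kernel comment. A hedged host-side check of that claim; sizes are illustrative and the batch term is dropped:

// Hedged sketch (not in the commit): verify the non-batched special case of
// get_sf_out_offset_128x4 against the kernel comment's simplified formula.
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr int kSfVecSize = 16; // CVT_FP4_SF_VEC_SIZE in the header

int64_t sf_offset_128x4(int mIdx, int kIdx, int numCols) // batch term dropped (batchIdx = 0)
{
    int32_t const innerKIdx = kIdx % 4;
    int32_t const innerMIdx = (mIdx % (32 * 4)) / 32;
    int32_t const outerMIdx = mIdx % 32;
    int32_t const kTileIdx = kIdx / 4;
    int const factor = kSfVecSize * 4; // numCols rounded up to a multiple of 64
    int32_t const numKTiles = (numCols + factor - 1) / factor;
    int32_t const mTileIdx = mIdx / (32 * 4);
    int64_t const mTileStride = numKTiles * 512;
    return mTileIdx * mTileStride + kTileIdx * 512 + outerMIdx * 16 + innerMIdx * 4 + innerKIdx;
}

int main()
{
    int const numRows = 256, numSfCols = 8; // illustrative SF matrix shape
    int const numSfTilesK = (numSfCols + 4 - 1) / 4;
    for (int mi = 0; mi < numRows; ++mi)
        for (int ki = 0; ki < numSfCols; ++ki)
        {
            // Simplified formula from the nvfp4_block_scale_interleave_kernel comment.
            int64_t const expected = ((mi / 128) * numSfTilesK + ki / 4) * 512
                + (mi % 32) * 16 + ((mi % 128) / 32) * 4 + ki % 4;
            assert(sf_offset_128x4(mi, ki, numSfCols * kSfVecSize) == expected);
        }
    std::printf("all %d offsets match\n", numRows * numSfCols);
    return 0;
}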

cpp/tensorrt_llm/kernels/quantization.h (+3)
@@ -74,5 +74,8 @@ template <typename T>
 void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* globalScale, int64_t* output,
     int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream = 0);

+void invokeNVFP4BlockScaleInterleave(
+    int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
+
 } // namespace kernels
 } // namespace tensorrt_llm
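Review note: a hedged usage sketch for the new host entry point declared above. The wrapper name interleave_weights_sf is illustrative; dSfIn is assumed to hold b * m * n scale-factor bytes on the device, and dSfOut must be sized for the padded 128x4 interleaved layout (rows padded to 128, SF columns to 4).

// Hedged sketch (not in the commit): calling the new launcher during weight loading.
#include <cstdint>
#include <cuda_runtime.h>
#include "tensorrt_llm/kernels/quantization.h"

void interleave_weights_sf(
    uint8_t const* dSfIn, uint8_t* dSfOut, int b, int m, int n, cudaStream_t stream)
{
    int device = 0, smCount = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, device);
    tensorrt_llm::kernels::invokeNVFP4BlockScaleInterleave(b, m, n, dSfIn, dSfOut, smCount, stream);
}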

cpp/tensorrt_llm/kernels/trtllmGenKernels/CMakeLists.txt (+1 −1)
@@ -17,6 +17,6 @@

 add_subdirectory(blockscaleGemm)
 add_subdirectory(fmha)
-add_subdirectory(fp8BlockScaleMoe)
 add_subdirectory(gemm)
 add_subdirectory(batchedGemm)
+add_subdirectory(blockScaleMoe)
