feat: [Deepseek] Add trtllm-gen MOE FP4 MOE backend #3387

Merged: 5 commits, Apr 21, 2025. Changes from all commits.

49 changes: 48 additions & 1 deletion cpp/tensorrt_llm/common/cudaDriverWrapper.cpp
@@ -30,7 +30,7 @@

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaDriverWrapper.h"

#include "tensorrt_llm/common/logger.h"
#include <cuda.h>

#include <cstdio>
@@ -175,9 +175,56 @@ CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX,
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}

+namespace
+{
+std::string stringify_launch_config(CUlaunchConfig const& config)
+{
+    std::stringstream ss;
+
+    // Grid dimensions (Driver API uses separate fields)
+    ss << "Grid Dimensions: (" << config.gridDimX << ", " << config.gridDimY << ", " << config.gridDimZ << ")\n";
+
+    // Block dimensions
+    ss << "Block Dimensions: (" << config.blockDimX << ", " << config.blockDimY << ", " << config.blockDimZ << ")\n";
+
+    // Shared memory and stream (Driver API uses hStream)
+    ss << "Shared Memory: " << config.sharedMemBytes << " bytes\n";
+    ss << "Stream: " << (config.hStream ? "Custom" : "Default") << " (0x" << std::hex
+       << reinterpret_cast<uintptr_t>(config.hStream) << std::dec << ")\n";
+
+    // Attributes (Driver API uses value instead of val)
+    ss << "Attributes (" << config.numAttrs << "):\n";
+    for (uint i = 0; i < config.numAttrs; ++i)
+    {
+        CUlaunchAttribute const& attr = config.attrs[i];
+        ss << "  [" << i << "] ";
+
+        switch (attr.id)
+        {
+        case CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION:
+            ss << "Cluster Dimension: (" << attr.value.clusterDim.x << ", " << attr.value.clusterDim.y << ", "
+               << attr.value.clusterDim.z << ")";
+            break;
+
+        case CU_LAUNCH_ATTRIBUTE_PRIORITY: ss << "Priority: " << attr.value.priority; break;
+
+        // Handle other Driver API attributes here
+        default: ss << "Unknown Attribute (ID=" << attr.id << ")"; break;
+        }
+        ss << "\n";
+    }
+
+    return ss.str();
+}
+} // namespace
+
CUresult CUDADriverWrapper::cuLaunchKernelEx(
CUlaunchConfig const* config, CUfunction f, void** kernelParams, void** extra) const
{
+
+    TLLM_LOG_DEBUG("Launch config: %s", stringify_launch_config(*config).c_str());
+    TLLM_CHECK_DEBUG_WITH_INFO(
+        (extra != nullptr) != (kernelParams != nullptr), "Exactly one of 'extra' and 'kernelParams' should be set.");
return (*_cuLaunchKernelEx)(config, f, kernelParams, extra);
}

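As a side note on the new logging path: below is a minimal, hypothetical caller sketch (not part of this PR) showing what the wrapper now does around a cluster launch. The `getInstance()` accessor and `TLLM_CU_CHECK` macro are assumed from the surrounding codebase; `func` and `kernelArgs` stand in for a real kernel and its arguments.

```cpp
#include "tensorrt_llm/common/cudaDriverWrapper.h"

#include <cuda.h>

// Hypothetical helper: launch `func` as a 2x1x1 thread-block cluster through the
// wrapper so the new debug logging and the kernelParams/extra check both run.
void launchWithCluster(CUfunction func, void** kernelArgs, CUstream stream)
{
    CUlaunchAttribute attr{};
    attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
    attr.value.clusterDim.x = 2;
    attr.value.clusterDim.y = 1;
    attr.value.clusterDim.z = 1;

    CUlaunchConfig config{};
    config.gridDimX = 64; // illustrative sizes only; must be divisible by the cluster size
    config.gridDimY = 1;
    config.gridDimZ = 1;
    config.blockDimX = 256;
    config.blockDimY = 1;
    config.blockDimZ = 1;
    config.sharedMemBytes = 0;
    config.hStream = stream;
    config.attrs = &attr;
    config.numAttrs = 1;

    auto driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
    // Passing kernelParams and a null `extra` satisfies the new check;
    // in debug builds the stringified config is logged before dispatch.
    TLLM_CU_CHECK(driver->cuLaunchKernelEx(&config, func, kernelArgs, /*extra=*/nullptr));
}
```

The check uses boolean inequality as XOR, so passing both pointers or neither trips the debug assertion.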
41 changes: 40 additions & 1 deletion cpp/tensorrt_llm/kernels/quantization.cu
@@ -259,7 +259,7 @@ void invokeBatchedFP4Quantization(int b, int m, int n, __nv_fp8_e4m3 const* inpu
dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 2048 / block.x;
-    dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+    dim3 grid(std::min(m, multiProcessorCount * numBlocksPerSM));

// Launch the cvt kernel.
if (useUE8M0)
@@ -277,6 +277,45 @@
}
#endif

+__global__ void nvfp4_block_scale_interleave_kernel(
+    int numbatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput)
+{
+    for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x)
+    {
+        for (int batchIdx = 0; batchIdx < numbatches; batchIdx++)
+        {
+            for (int colIdx = threadIdx.x; colIdx < numCols; colIdx += blockDim.x)
+            {
+                // Widen before multiplying so large b * m * n products do not overflow int.
+                int64_t inOffset = static_cast<int64_t>(batchIdx) * numRows * numCols + rowIdx * numCols + colIdx;
+                auto sf = SFIn[inOffset];
+
+                std::optional<int> batchIdxOpt = batchIdx;
+                std::optional<int> numRowsOpt = numRows;
+
+                // Without batching, the math in get_sf_out_offset is the same as
+                // int const numSfTilesK = (numCols + 4 - 1) / 4;
+                // int const tileOffset = ((mi / 128) * numSfTilesK + ki / 4) * 512;
+                // int const dstIdx = tileOffset + (mi % 32) * 16 + ((mi % 128) / 32) * 4 + ki % 4;
+                auto dstIdx = get_sf_out_offset_128x4(batchIdxOpt, rowIdx, colIdx, numRowsOpt, numCols * 16);
+                SFOutput[dstIdx] = sf;
+            }
+        }
+    }
+}
+
+// This is intended for weight loading, so m and n are large, b <= 256
+void invokeNVFP4BlockScaleInterleave(
+    int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream)
+{
+    // Each thread reads 1 int8 value
+    dim3 block(std::min(n, 1024));
+    // Get number of blocks per SM (assume we can fully utilize the SM).
+    int const numBlocksPerSM = 4096 / block.x;
+    dim3 grid(std::min(m, multiProcessorCount * numBlocksPerSM));
+
+    nvfp4_block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
+}
+
// Instantiate the function.
template void invokeFP4Quantization(int m, int n, half const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, int multiProcessorCount, cudaStream_t stream);
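To make the swizzle concrete: the comment in nvfp4_block_scale_interleave_kernel claims that, without batching, get_sf_out_offset_128x4 collapses to a three-line tile formula. Here is a standalone host-side sketch (plain C++, no TRT-LLM headers; written for this writeup, not part of the PR) that cross-checks the two formulations:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Single-batch 128x4 swizzle, as written in the kernel comment.
// mi = SF row, ki = SF column, numCols = number of SF columns.
int64_t swizzledOffsetShort(int mi, int ki, int numCols)
{
    int const numSfTilesK = (numCols + 4 - 1) / 4;
    int const tileOffset = ((mi / 128) * numSfTilesK + ki / 4) * 512;
    return tileOffset + (mi % 32) * 16 + ((mi % 128) / 32) * 4 + ki % 4;
}

// The same offset via the stride decomposition used by get_sf_out_offset_128x4
// (batchIdx = 0; numCols is already in SF units here, so the x16/x64 factors cancel).
int64_t swizzledOffsetStrides(int mi, int ki, int numCols)
{
    int32_t const numKTiles = (numCols + 4 - 1) / 4;
    int64_t const kTileStride = 512; // 32 * 16
    int64_t const mTileStride = numKTiles * kTileStride;
    return (mi / 128) * mTileStride + (ki / 4) * kTileStride // tile base
        + (mi % 32) * 16                                     // outer M (column-major 32x4 tile)
        + ((mi % 128) / 32) * 4                              // inner M
        + (ki % 4);                                          // inner K
}

int main()
{
    for (int mi = 0; mi < 300; ++mi)
        for (int ki = 0; ki < 8; ++ki)
            assert(swizzledOffsetShort(mi, ki, 8) == swizzledOffsetStrides(mi, ki, 8));
    std::printf("%lld\n", (long long) swizzledOffsetShort(37, 5, 8)); // prints 597
    return 0;
}
```

For example, row 37, SF column 5 of a [300 x 8] scale tensor lands at byte 597: tile (0, 1) starts at 512, plus 5 * 16 outer-M, 1 * 4 inner-M, and 1 inner-K inside the 128x4 tile.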
80 changes: 44 additions & 36 deletions cpp/tensorrt_llm/kernels/quantization.cuh
@@ -527,8 +527,7 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal)) * reciprocal(SFScaleVal))
-    float outputScale
-        = SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
+    float outputScale = SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValue) : 0.0f;

if (SFout)
{
@@ -557,6 +556,46 @@
#endif
}

+inline __device__ int64_t get_sf_out_offset_128x4(
+    std::optional<int> batchIdx, int mIdx, int kIdx, std::optional<int> numRows, int numCols)
+{
+    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
+
+    // batched tensor
+    // SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+    // --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
+
+    int32_t innerKIdx = (kIdx % 4);
+    int64_t innerKStride = 1;
+
+    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
+    int64_t innerMStride = 4 * innerKStride; // 4
+
+    // M tile layout [32, 4] is column-major.
+    int32_t outerMIdx = (mIdx % 32);
+    int64_t outerMStride = 4 * innerMStride; // 16
+
+    int32_t kTileIdx = (kIdx / 4);
+    int64_t kTileStride = 32 * outerMStride; // 512
+
+    // SF vector size 16. We round the "numCols" up to a multiple of 64.
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    int32_t numKTiles = (numCols + factor - 1) / factor;
+    int32_t mTileIdx = mIdx / (32 * 4);
+    int64_t mTileStride = numKTiles * kTileStride;
+
+    // Each SF block has 128 rows so pad rows to the multiple of 128.
+    int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128;
+    int64_t bTileStride = numMTiles * mTileStride;
+
+    // Compute the global offset.
+    int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + kTileIdx * kTileStride
+        + outerMIdx * outerMStride + innerMIdx * innerMStride + innerKIdx * innerKStride;
+
+    return SFOffset;
+}
+
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx, int colIdx,
std::optional<int> numRows, int numCols, SFType* SFout, FP4QuantizationSFLayout layout)
@@ -576,40 +615,7 @@ __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchI
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;

-        // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
-        // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
-
-        // batched tensor
-        // SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
-        // --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
-
-        int32_t innerKIdx = (kIdx % 4);
-        int64_t innerKStride = 1;
-
-        int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
-        int64_t innerMStride = 4 * innerKStride; // 4
-
-        // M tile layout [32, 4] is column-major.
-        int32_t outerMIdx = (mIdx % 32);
-        int64_t outerMStride = 4 * innerMStride; // 16
-
-        int32_t kTileIdx = (kIdx / 4);
-        int64_t kTileStride = 32 * outerMStride; // 512
-
-        // SF vector size 16. We round the "numCols" up to a multiple of 64.
-        int factor = CVT_FP4_SF_VEC_SIZE * 4;
-        int32_t numKTiles = (numCols + factor - 1) / factor;
-        int32_t mTileIdx = mIdx / (32 * 4);
-        int64_t mTileStride = numKTiles * kTileStride;
-
-        // Each SF block has 128 rows so pad rows to the multiple of 128.
-        int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128;
-        int64_t bTileStride = numMTiles * mTileStride;
-
-        // Compute the global offset.
-        int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + kTileIdx * kTileStride
-            + outerMIdx * outerMStride + innerMIdx * innerMStride + innerKIdx * innerKStride;
-
+        auto SFOffset = get_sf_out_offset_128x4(batchIdx, mIdx, kIdx, numRows, numCols);
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
else if (layout == FP4QuantizationSFLayout::LINEAR)
@@ -819,5 +825,7 @@ cvt_fp8_to_fp4(
#endif
}

+__global__ void nvfp4_block_scale_interleave_kernel(
+    int numbatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput);
} // namespace kernels
} // namespace tensorrt_llm
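One note on the outputScale change in cvt_warp_fp8_to_fp4 above: algebraically, both the old and new expressions equal SFScaleVal / SFValue; the rewrite just trades two approximate reciprocals for one. A host-side sketch (for this writeup, not from the PR) with exact division standing in for reciprocal_approximate_ftz, which is approximate on the device, so the two forms can differ there by an ulp or so:

```cpp
#include <cstdio>

int main()
{
    float const SFValue = 12.0f;   // example per-block scale after the fp8 round-trip
    float const SFScaleVal = 0.5f; // example reciprocal of the global scale

    // Old expression: reciprocal(SFValue * reciprocal(SFScaleVal))
    float const oldScale = 1.0f / (SFValue * (1.0f / SFScaleVal));
    // New expression: SFScaleVal * reciprocal(SFValue)
    float const newScale = SFScaleVal * (1.0f / SFValue);

    // Both are SFScaleVal / SFValue; the new form needs one reciprocal instead of two.
    std::printf("%g %g\n", oldScale, newScale); // 0.0416667 0.0416667
    return 0;
}
```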
3 changes: 3 additions & 0 deletions cpp/tensorrt_llm/kernels/quantization.h
@@ -74,5 +74,8 @@ template <typename T>
void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* globalScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream = 0);

+void invokeNVFP4BlockScaleInterleave(
+    int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
+
} // namespace kernels
} // namespace tensorrt_llm
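Below is a hypothetical weight-loading caller for the new entry point (not from this PR; the include path, device 0, and helper name are assumptions). The input SF tensor is row-major [m, n] with n = k / 16 scale bytes per row; per the 128x4 tiling above, the swizzled output needs ceil(m/128) * ceil(n/4) * 512 bytes per batch, zeroed so the padding rows stay defined:

```cpp
#include "tensorrt_llm/kernels/quantization.h"

#include <cuda_runtime.h>
#include <cstdint>

// Interleave the FP4 block scales of a single [m x k] weight (FP4 block size 16).
void interleaveWeightScales(uint8_t const* sfLinearDev, uint8_t** sfSwizzledDev, int m, int k, cudaStream_t stream)
{
    int const n = k / 16; // one uint8 scale per 16 FP4 elements
    int64_t const outBytes = int64_t((m + 127) / 128) * ((n + 3) / 4) * 512;

    int smCount = 0;
    cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, /*device=*/0);

    cudaMallocAsync(reinterpret_cast<void**>(sfSwizzledDev), outBytes, stream);
    cudaMemsetAsync(*sfSwizzledDev, 0, outBytes, stream); // keep the 128x4 padding defined
    tensorrt_llm::kernels::invokeNVFP4BlockScaleInterleave(
        /*b=*/1, m, n, sfLinearDev, *sfSwizzledDev, smCount, stream);
}
```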
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/kernels/trtllmGenKernels/CMakeLists.txt
@@ -17,6 +17,6 @@

add_subdirectory(blockscaleGemm)
add_subdirectory(fmha)
-add_subdirectory(fp8BlockScaleMoe)
add_subdirectory(gemm)
add_subdirectory(batchedGemm)
+add_subdirectory(blockScaleMoe)