
Commit dc9937c

pamelap-nvidia authored and schetlur-nv committed
feat: Add FP8 support for SM 120 (NVIDIA#3248)
* Allow FP8 on SM120
* fix sm121
* fix
* fix pre-commit
* review update

---------

Signed-off-by: Pamela Peng <[email protected]>
Co-authored-by: Sharan Chetlur <[email protected]>
Signed-off-by: Luis Vega <[email protected]>
1 parent a09aa4f commit dc9937c

26 files changed (+70, -60 lines)

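Every file below applies the same gating rule: SM 120/121 (GB20x, GeForce Blackwell) is routed onto the Ada (SM 89) FP8 code paths, while Hopper- and SM 100-only features (FP4, programmatic dependent launch, the fused lamport all-reduce) are fenced off from it. A minimal sketch of that rule, assuming only an SM-version query such as tensorrt_llm::common::getSMVersion(); the helper names are illustrative and are not part of this commit, which writes the conditions inline:

// Illustrative helpers only; sm is the compute capability as an integer (89, 90, 100, 120, ...).
inline bool useAdaStyleFp8Kernels(int sm)
{
    // GeForce Blackwell reuses the Ada (SM 89) FP8 GEMM/FMHA kernels.
    return sm == 89 || sm >= 120;
}

inline bool hasHopperClassFeatures(int sm)
{
    // PDL intrinsics, lamport fused all-reduce, FP4, etc. remain limited to SM 90..11x.
    return sm >= 90 && sm < 120;
}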

cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h

+1-1
@@ -383,7 +383,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 static_assert(!FP4, "FP4 Tests enabled on unsupported CUDA version");
 #endif
 bool should_skip_unsupported_fp8 = getSMVersion() < 89 && FP8;
-bool should_skip_unsupported_fp4 = getSMVersion() < 100 && FP4;
+bool should_skip_unsupported_fp4 = (getSMVersion() < 100 || getSMVersion() >= 120) && FP4;
 return should_skip_unsupported_fp8 || should_skip_unsupported_fp4;
 }

cpp/tensorrt_llm/common/attentionOp.cpp

+4-4
@@ -209,7 +209,7 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
 xqaParams.kv_cache_data_type = xqaParams.data_type;
 }
 if (xqaParams.kv_cache_data_type == DATA_TYPE_INT8
-|| (xqaParams.kv_cache_data_type == DATA_TYPE_E4M3 && mSM < kSM_90))
+|| (xqaParams.kv_cache_data_type == DATA_TYPE_E4M3 && (mSM < kSM_90 || mSM >= kSM_120)))
 {
 xqaParams.multi_block_mode = false;
 }
@@ -2276,8 +2276,8 @@ int AttentionOp::initialize() noexcept
 if (mFP8ContextFMHA)
 {
 TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "FP8 FMHA cannot be enabled because Context FMHA is not supported.");
-TLLM_CHECK_WITH_INFO(
-mSM == 89 || mSM == 90 || mSM == 100, "FP8 FMHA can only be enabled on sm_89, sm_90 or sm_100.");
+TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 120,
+"FP8 FMHA can only be enabled on sm_89, sm_90, sm_100 or sm_120.");
 }

 // Pre-Check of FP8 Generation MLA.
@@ -2290,7 +2290,7 @@

 // Check requirements for FP4 output.
 TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mEnableContextFMHA, "Context FMHA must enable if fuse_fp4_quant is enabled");
-TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || (mSM >= 100), "fuse_fp4_quant only supports SM100 and later devices.");
+TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100, "fuse_fp4_quant only supports SM100 devices.");

 TLLM_CHECK(isRoPE() == (mRotaryEmbeddingDim != 0));
 TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16),
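Taken together, the two attention changes above amount to the following predicates; these are hypothetical free functions for illustration only, while AttentionOp keeps the conditions inline:

// Illustrative only. sm is the device compute capability (kSM_90 == 90, kSM_120 == 120).
inline bool fp8FmhaSupported(int sm)
{
    return sm == 89 || sm == 90 || sm == 100 || sm == 120;
}

inline bool xqaFp8KvCacheMultiBlockSupported(int sm)
{
    // Multi-block XQA with an E4M3 KV cache stays disabled outside SM 90..11x.
    return sm >= 90 && sm < 120;
}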

cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp

+2-2
@@ -415,7 +415,7 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
 mLaunchParams.kernel_s = 0;
 mLaunchParams.force_unroll = true;
 // enable tiled kernels on Ampere/Ada
-if (isSm89 && mFixedParams.dataType == DATA_TYPE_E4M3)
+if ((isSm89 || isSm120) && mFixedParams.dataType == DATA_TYPE_E4M3)
 {
 // so far Ada QMMA only supports non-tiled kernels.
 mLaunchParams.granular_tiling = false;
@@ -427,7 +427,7 @@
 // can suffer from tile quantization loss therefore use flash attention non-tiled instead
 mLaunchParams.granular_tiling = false;
 }
-else if (isSm8x && mFixedParams.headSize < 256)
+else if ((isSm8x || isSm120) && mFixedParams.headSize < 256)
 {
 // flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
 mLaunchParams.granular_tiling = false;

cpp/tensorrt_llm/kernels/customAllReduceKernels.cu

+14-14
@@ -266,7 +266,7 @@ __global__ void rms_norm_kernel(AllReduceParams params)
 local_final_output_buffer += block_offset;
 intermediate_buffer += block_offset;

-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaGridDependencySynchronize();
 #endif

@@ -309,7 +309,7 @@ __global__ void rms_norm_kernel(AllReduceParams params)
 inter_vec.packed = rms_norm<T, Affine>(denom, inter_vec, weight_vec);
 *reinterpret_cast<int4*>(&local_final_output_buffer[offset]) = inter_vec.packed;
 }
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
@@ -340,7 +340,7 @@ __global__ void rms_pre_post_norm_kernel(AllReduceParams params) // for gemma2 p
 local_final_output_buffer += block_offset;
 intermediate_buffer += block_offset;

-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaGridDependencySynchronize();
 #endif

@@ -393,7 +393,7 @@ __global__ void rms_pre_post_norm_kernel(AllReduceParams params) // for gemma2 p
 inter_vec.packed = rms_norm<T, Affine>(denom, inter_vec, weight_vec);
 *reinterpret_cast<int4*>(&local_final_output_buffer[offset]) = inter_vec.packed;
 }
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
@@ -744,7 +744,7 @@ struct Reducer<T, RanksPerNode, false>
 template <int ClusterSize, typename T, int RanksPerNode, bool Bias = false, bool Affine = false, bool PushMode = true>
 static __global__ void lamport_style_one_shot_all_reduce_norm_kernel(AllReduceParams params)
 {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 namespace cg = cooperative_groups;
 static_assert(RanksPerNode <= MAX_RANKS_PER_NODE);
 static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
@@ -937,7 +937,7 @@ static __global__ void __launch_bounds__(1024, 1) one_shot_all_reduce_norm_kerne
 buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
 }

-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaGridDependencySynchronize();
 #endif

@@ -1001,7 +1001,7 @@ static __global__ void __launch_bounds__(1024, 1) one_shot_all_reduce_norm_kerne
 *reinterpret_cast<int4*>(&local_final_output_buffer[norm_offset + offset]) = sum_vec.packed;
 }
 }
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
@@ -1044,7 +1044,7 @@ static __global__ void __launch_bounds__(1024, 1) one_shot_prenorm_all_reduce_no
 buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
 }

-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaGridDependencySynchronize();
 #endif

@@ -1114,7 +1114,7 @@ static __global__ void __launch_bounds__(1024, 1) one_shot_prenorm_all_reduce_no
 sum_vec.packed = rms_norm<T, Affine>(denom, sum_vec, weight_vec);
 *reinterpret_cast<int4*>(&local_final_output_buffer[norm_offset + thread_offset]) = sum_vec.packed;
 }
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
@@ -1128,7 +1128,7 @@ bool is_lamport_supported(int token_num, int hidden_size)
 if (disableLamportReduceNormFusion)
 return false;
 static int sm = tensorrt_llm::common::getSMVersion();
-if (sm < 90)
+if (sm < 90 || sm >= 120)
 {
 return false;
 }
@@ -1355,7 +1355,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
 buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
 }

-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaGridDependencySynchronize();
 #endif

@@ -1424,7 +1424,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
 *reinterpret_cast<int4*>(&local_output_buffer[iter_offset]) = sums.packed;
 }

-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
@@ -1497,7 +1497,7 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
 buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
 }

-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaGridDependencySynchronize();
 #endif

@@ -1631,7 +1631,7 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
 }
 }

-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
 cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
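All of the guards above follow one pattern: the programmatic dependent launch (PDL) intrinsics are compiled only for sm_90 through sm_11x, so GB20x builds fall back to ordinary kernel boundaries. A minimal sketch of that pattern, using a hypothetical kernel and macro name (the file writes the preprocessor condition out inline each time):

#include <cuda_runtime.h>

// Hypothetical convenience macro mirroring the guard added in this commit.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 1200))
#define SKETCH_HAS_PDL 1
#else
#define SKETCH_HAS_PDL 0
#endif

__global__ void sketchKernel(float* data)
{
#if SKETCH_HAS_PDL
    // Wait for the upstream grid to reach its completion point.
    cudaGridDependencySynchronize();
#endif
    data[threadIdx.x] += 1.0f;
#if SKETCH_HAS_PDL
    // Allow a dependent grid to begin its prologue early.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}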

cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

+3-3
@@ -175,7 +175,7 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
 case CutlassGemmType::Fp8:
 if (config_type_param & CutlassGemmConfig::GROUPED_GEMM)
 {
-if (sm == 89)
+if (sm == 89 || sm >= 120)
 {
 return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128,
 CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
@@ -193,7 +193,7 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
 }
 else
 {
-if (sm == 89)
+if (sm == 89 || sm >= 120)
 {
 return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
 CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
@@ -414,7 +414,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
 {
 return get_candidate_configs_sm90(config_type_param);
 }
-if (sm >= 100 && sm != 120 && (config_type_param & CutlassGemmConfig::BLACKWELL))
+if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL))
 {
 return get_candidate_configs_sm100(config_type_param);
 }

cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp

+1-1
@@ -571,7 +571,7 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, in
 arch = 80;
 }
 // Force use sm80 kernel for GB20x.
-if (arch == 120)
+if (arch >= 120)
 {
 arch = 80;
 }

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h

+2-2
@@ -508,7 +508,7 @@ size_t CutlassFp8RowwiseGemmRunner<T>::dispatchToArch(void* D, void const* A, vo
 return dispatchGemmToCutlassSm90<T>(D, A, B, C_bias, quantOption, m, n, k, scale_d0, scale_d1, gemmConfig,
 workspace, workspaceBytes, stream, occupancy);
 }
-else if (mSm == 89)
+else if (mSm == 89 || mSm >= 120)
 {
 return dispatchGemmToCutlassSm89<T>(D, A, B, C_bias, quantOption, m, n, k, scale_d0, scale_d1, gemmConfig,
 workspace, workspaceBytes, stream, occupancy);
@@ -574,7 +574,7 @@ std::vector<tkc::CutlassGemmConfig> CutlassFp8RowwiseGemmRunner<T>::getConfigs()
 }
 }
 }
-else if (mSm == 89)
+else if (mSm == 89 || mSm >= 120)
 {
 tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param
 = tkc::CutlassGemmConfig::CandidateConfigTypeParam::FP8_ONLY;
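The dispatch above reduces to a small architecture selection, sketched below with illustrative names; the real runner forwards many more arguments and handles additional architectures not shown in these hunks:

// Illustrative only: which kernel family the FP8 rowwise GEMM picks per arch after this change.
enum class Fp8RowwiseKernelArch { Sm90, Sm89, Unsupported };

inline Fp8RowwiseKernelArch selectFp8RowwiseArch(int sm)
{
    if (sm == 90)
        return Fp8RowwiseKernelArch::Sm90;   // Hopper kernels
    if (sm == 89 || sm >= 120)
        return Fp8RowwiseKernelArch::Sm89;   // Ada kernels, reused on GB20x
    return Fp8RowwiseKernelArch::Unsupported; // other branches omitted in this sketch
}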

cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h

+1-1
@@ -334,7 +334,7 @@ void CutlassInt8GemmRunner<T>::dispatchToArch(int8_t const* A, int8_t const* B,
 dispatchGemmToCutlass<T, cutlass::arch::Sm75>(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr,
 workspaceBytes, gemmConfig, stream, occupancy);
 }
-else if (mSm >= 80 && mSm <= 90)
+else if (mSm >= 80 && mSm <= 90 || mSm >= 120)
 {
 dispatchGemmToCutlass<T, cutlass::arch::Sm80>(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr,
 workspaceBytes, gemmConfig, stream, occupancy);

cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py

+3-2
@@ -393,7 +393,7 @@ def is_grouped_gemm_op_valid(op):


 def is_op_valid(op):
-if op.arch >= 100:
+if op.arch >= 100 and op.arch < 120:
 return is_gemm_op_valid_sm100(op)

 if op.gemm_kind == GemmKind.Gemm:
@@ -666,7 +666,8 @@ def has_arch(sm):
 operations = []
 operations += generate_sm100_operations(has_arch(100))
 operations += generate_sm90_operations(has_arch(90))
-operations += generate_sm80_operations(has_arch(80) or has_arch(89))
+operations += generate_sm80_operations(
+has_arch(80) or has_arch(89) or has_arch(120))

 def should_skip(op):
 is_internal = op.gemm_kind == GemmKind.Grouped

cpp/tensorrt_llm/kernels/mambaConv1dKernels.cu

+1-1
@@ -793,7 +793,7 @@ void invokeMambaConv1dContext(MambaConv1dParamsBase& params, cudaStream_t stream

 if (std::is_same_v<input_t, float>)
 {
-if (tensorrt_llm::common::getSMVersion() >= 90)
+if (tensorrt_llm::common::getSMVersion() >= 90 && tensorrt_llm::common::getSMVersion() < 120)
 {
 if (B * L * D <= 262144)
 {

cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h

+1-1
@@ -53,7 +53,7 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
 }
 else if ((arch >= 80 && arch < 90) || arch >= 100)
 {
-if (arch == 89)
+if (arch == 89 || arch >= 120)
 {
 EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
 EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);

cpp/tensorrt_llm/plugins/fp4GemmPlugin/fp4GemmPlugin.cpp

+2-1
@@ -116,7 +116,8 @@ Fp4GemmPlugin::Fp4GemmPlugin(void const* data, size_t length, Fp4GemmPlugin::Plu

 void Fp4GemmPlugin::init(nvinfer1::DataType type)
 {
-TLLM_CHECK_WITH_INFO((getSMVersion() >= 100), "FP4 Gemm not supported before Blackwell");
+TLLM_CHECK_WITH_INFO((getSMVersion() >= 100 && getSMVersion() < 120),
+"FP4 Gemm not supported before Blackwell, nor GeForce Blackwell");
 TLLM_CHECK_WITH_INFO(
 (mOutputType == DataType::kBF16) || (mOutputType == DataType::kFLOAT) || (mOutputType == DataType::kHALF),
 "Only support float, half, bfloat16, got %d.", (int) mOutputType);

cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp

+4-3
@@ -389,17 +389,18 @@ int GemmPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P
 }

 bool cudaKernelFinished = false;
+bool isArch90or100 = mArch >= 90 && mArch < 120;
 // TODO: sub tensor matmul is not supported in fp8 gemm cuda kernel
-if (mArch < 90 && M <= 4 && N <= 128000 && mUseFp8 && noPadDim && cudaKernelSupportType)
+if (!isArch90or100 && M <= 4 && N <= 128000 && mUseFp8 && noPadDim && cudaKernelSupportType)
 {
 tensorrt_llm::common::QuantMode quantMode = tensorrt_llm::common::QuantMode::fromQuantAlgo("FP8");
 tensorrt_llm::kernels::cuda_core_gemm::Params params(reinterpret_cast<void const*>(inputs[0]),
 reinterpret_cast<void const*>(inputs[1]), mAlpha, reinterpret_cast<void*>(outputs[0]), M, N, K, quantMode,
 nvinfer1::DataType::kFP8, mOutputType);
 cudaKernelFinished = tensorrt_llm::kernels::cuda_core_gemm::cudaCoreGemmDispatcher(params, stream);
 }
-else if (((mArch < 90 && M <= 6) || (mArch >= 90 && M <= 2)) && N <= 128000 && !mUseFp8 && noPadDim
-&& cudaKernelSupportType)
+else if (!isArch90or100 && ((mArch < 90 && M <= 6) || (isArch90or100 && M <= 2)) && N <= 128000 && !mUseFp8
+&& noPadDim && cudaKernelSupportType)
 {
 tensorrt_llm::common::QuantMode quantMode;
 tensorrt_llm::kernels::cuda_core_gemm::Params params(reinterpret_cast<void const*>(inputs[0]),

cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp

+3-2
@@ -257,8 +257,9 @@ void MixtureOfExpertsPlugin::init()
 "MOE plugin only supports a different output type for FP4/FP8");
 TLLM_CHECK_WITH_INFO(mType != DataType::kFP8 || tensorrt_llm::common::getSMVersion() >= 89,
 "MoE FP8 is not supported for architectures less than SM89");
-TLLM_CHECK_WITH_INFO(mType != DataType::kFP4 || tensorrt_llm::common::getSMVersion() >= 100,
-"MoE FP4 is not supported for architectures less than SM100");
+TLLM_CHECK_WITH_INFO(mType != DataType::kFP4
+|| (tensorrt_llm::common::getSMVersion() >= 100 && tensorrt_llm::common::getSMVersion() < 120),
+"MoE FP4 is only supported on architecture SM100");

 TLLM_CHECK_WITH_INFO(!hasLora() || mLoraType == mOutputType, "The LoraType need to keep same with moe OutputType.");

cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu

+10-6
@@ -161,7 +161,7 @@ protected:
 #endif
 bool should_skip_no_device = mDeviceCount <= 0;
 bool should_skip_unsupported_fp8 = getSMVersion() < 89 && FP8;
-bool should_skip_unsupported_fp4 = getSMVersion() < 100 && FP4;
+bool should_skip_unsupported_fp4 = (getSMVersion() < 100 || getSMVersion() >= 120) && FP4;
 return should_skip_no_device || should_skip_unsupported_fp8 || should_skip_unsupported_fp4;
 }

@@ -862,7 +862,7 @@ protected:
 auto getFilteredConfigs(int sm)
 {
 auto tactics = mMoERunner.getTactics();
-if (sm == 89)
+if (sm == 89 || sm >= 120)
 {
 // Filter some unsupported configs for L40S
 auto it = std::remove_if(tactics.begin(), tactics.end(),
@@ -1308,7 +1308,8 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest(
 auto [expected_experts, token_final_scales] = populateRouting(num_experts, num_tokens, k);

 runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k);
-bool should_be_deterministic = mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90;
+bool should_be_deterministic
+= mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
 if (should_be_deterministic && !mIsLongTest)
 {
 auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@@ -1546,7 +1547,8 @@
 // Only need to init the inputs on the first iteration
 runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k,
 MOEParallelismConfig{tp_size, i, ep_size, j});
-bool should_be_deterministic = mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90;
+bool should_be_deterministic
+= mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
 if (should_be_deterministic && !mIsLongTest)
 {
 auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@@ -1560,7 +1562,8 @@
 else
 {
 runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j});
-bool should_be_deterministic = mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90;
+bool should_be_deterministic
+= mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
 if (should_be_deterministic && !mIsLongTest)
 {
 auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@@ -1866,7 +1869,8 @@ TEST_F(MixtureOfExpertsProfilerTest, TestGeneratedProfilerDistribution)

 backend.prepare(num_tokens, workspace, mStream->get());

-auto getNext = backend.getWorkspacePointerGenerator(workspace, num_tokens, getSMVersion() >= 90);
+auto getNext = backend.getWorkspacePointerGenerator(
+workspace, num_tokens, getSMVersion() >= 90 && getSMVersion() < 120);
 auto const* expert_first_token_offset_size = reinterpret_cast<int64_t*>(getNext());
 auto const* source_to_dest_map = reinterpret_cast<int*>(getNext());
 auto const* dest_to_source_map = reinterpret_cast<int*>(getNext());
cpp/tests/unit_tests/kernels/ropeTest.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ protected:
498498
if constexpr (std::is_same_v<KVCacheType, __nv_fp4_e2m1>)
499499
{
500500
// Quant helper functions will not work on lower SM versions.
501-
return getSMVersion() < 100;
501+
return getSMVersion() < 100 || getSMVersion() >= 120;
502502
}
503503
#endif
504504
return false;
