
Commit b46fde8

Merge branch 'main' into yunhsuanc/offload_ptable_draft
2 parents 27e4969 + 5bdf997 commit b46fde8

22 files changed: +766 −102 lines

cpp/tensorrt_llm/common/cudaFp8Utils.cu (+40 −9)

@@ -16,6 +16,7 @@
 #include "tensorrt_llm/common/cudaFp8Utils.h"
 #include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/common/reduceKernelUtils.cuh"
 #include <algorithm>
 #include <cstdio>
@@ -40,6 +41,10 @@ __inline__ __device__ float scale(float a, float b)
 template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
 __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
 {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.wait;");
+#endif
+
     for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
     {

@@ -56,6 +61,9 @@ __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
             output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[0])));
         }
     }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }

 template <typename T_OUT, typename T_S, typename T_IN>
@@ -64,18 +72,30 @@ void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* inp
 {
     dim3 grid(1024);
     dim3 block(CTA_SIZE);
+    cudaLaunchConfig_t config;
+    config.gridDim = grid;
+    config.blockDim = block;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
     if (quantize_mode == QuantizeMode::PER_CHANNEL)
     {
-        scaleMatrix<QuantizeMode::PER_CHANNEL, true>
-            <<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_CHANNEL, true, T_OUT, T_S, T_IN>, output, input_scale,
+            input, numel, lda);
     }
     else if (quantize_mode == QuantizeMode::PER_TOKEN)
     {
-        scaleMatrix<QuantizeMode::PER_TOKEN, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_TOKEN, true, T_OUT, T_S, T_IN>, output, input_scale,
+            input, numel, lda);
     }
     else if (quantize_mode == QuantizeMode::PER_TENSOR)
    {
-        scaleMatrix<QuantizeMode::PER_TENSOR, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_TENSOR, true, T_OUT, T_S, T_IN>, output, input_scale,
+            input, numel, lda);
     }
     sync_check_cuda_error(stream);
 }
@@ -86,19 +106,30 @@ void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
 {
     dim3 grid(1024);
     dim3 block(CTA_SIZE);
+    cudaLaunchConfig_t config;
+    config.gridDim = grid;
+    config.blockDim = block;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
     if (quantize_mode == QuantizeMode::PER_CHANNEL)
     {
-        scaleMatrix<QuantizeMode::PER_CHANNEL, false>
-            <<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_CHANNEL, false, T_OUT, T_S, T_IN>, output,
+            input_scale, input, numel, lda);
     }
     else if (quantize_mode == QuantizeMode::PER_TOKEN)
     {
-        scaleMatrix<QuantizeMode::PER_TOKEN, false><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_TOKEN, false, T_OUT, T_S, T_IN>, output, input_scale,
+            input, numel, lda);
     }
     else if (quantize_mode == QuantizeMode::PER_TENSOR)
     {
-        scaleMatrix<QuantizeMode::PER_TENSOR, false>
-            <<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_TENSOR, false, T_OUT, T_S, T_IN>, output, input_scale,
+            input, numel, lda);
     }
     sync_check_cuda_error(stream);
 }

cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h (+29 −6)

@@ -20,6 +20,7 @@
 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/cudaTypeUtils.cuh"
 #include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/common/reduceKernelUtils.cuh"
 #include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h"
 #include "tensorrt_llm/kernels/gptKernels.h"
@@ -778,6 +779,9 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu

     // Head idx.
     int const head_idx = blockIdx.y;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.wait;");
+#endif

     // Variable sequence length.
     bool const variable_sequence_length = params.tokens_info != nullptr && params.cu_seq_lens != nullptr;
@@ -1093,6 +1097,9 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
             params.fmha_bmm2_scale[0] = o_scale_orig_quant * kv_scale_quant_orig;
         }
     }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }

 // Use more blocks for the batch dimension in the generation phase.
@@ -1255,22 +1262,38 @@ void kernelV1Dispatch(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStrea
     dim3 grid(1, params.head_num); \
     int num_blocks_for_tokens = int(divUp(params.token_num, tokens_per_cuda_block)); \
     calGridSizeWithBestEfficiency(block, grid, num_blocks_for_tokens, params.multi_processor_count, 1024); \
+    cudaLaunchConfig_t config; \
+    config.gridDim = grid; \
+    config.blockDim = block; \
+    config.dynamicSmemBytes = 0; \
+    config.stream = stream; \
+    cudaLaunchAttribute attrs[1]; \
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); \
+    config.numAttrs = 1; \
+    config.attrs = attrs; \
     if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX \
         || params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE \
         || params.position_embedding_type == PositionEmbeddingType::kROPE_M) \
     { \
-        applyBiasRopeUpdateKVCacheV2<T, TCache, BLOCK_SIZE, Dh, ADD_BIAS, STORE_QKV, FP8_OUTPUT, GEN_PHASE, \
-            KVCacheBuffer, RotaryPositionEmbeddingType::GPT_NEOX><<<grid, block, 0, stream>>>(params); \
+        cudaLaunchKernelEx(&config, \
+            applyBiasRopeUpdateKVCacheV2<T, TCache, BLOCK_SIZE, Dh, ADD_BIAS, STORE_QKV, FP8_OUTPUT, GEN_PHASE, \
+                KVCacheBuffer, RotaryPositionEmbeddingType::GPT_NEOX>, \
+            params); \
     } \
     else if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPTJ) \
     { \
-        applyBiasRopeUpdateKVCacheV2<T, TCache, BLOCK_SIZE, Dh, ADD_BIAS, STORE_QKV, FP8_OUTPUT, GEN_PHASE, \
-            KVCacheBuffer, RotaryPositionEmbeddingType::GPTJ><<<grid, block, 0, stream>>>(params); \
+        cudaLaunchKernelEx(&config, \
+            applyBiasRopeUpdateKVCacheV2<T, TCache, BLOCK_SIZE, Dh, ADD_BIAS, STORE_QKV, FP8_OUTPUT, GEN_PHASE, \
+                KVCacheBuffer, RotaryPositionEmbeddingType::GPTJ>, \
+            params); \
     } \
     else \
    { \
-        applyBiasRopeUpdateKVCacheV2<T, TCache, BLOCK_SIZE, Dh, ADD_BIAS, STORE_QKV, FP8_OUTPUT, GEN_PHASE, \
-            KVCacheBuffer, RotaryPositionEmbeddingType::NONE><<<grid, block, 0, stream>>>(params); \
+        cudaLaunchKernelEx(&config, \
+            applyBiasRopeUpdateKVCacheV2<T, TCache, BLOCK_SIZE, Dh, ADD_BIAS, STORE_QKV, FP8_OUTPUT, GEN_PHASE, \
+                KVCacheBuffer, RotaryPositionEmbeddingType::NONE>, \
+            params); \
     }

 #define STORE_QKV_AND_FP8_OUTPUT_DISPATCH(ADD_BIAS, GEN_PHASE) \
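
Both files above apply the same Programmatic Dependent Launch (PDL) pattern: the kernel brackets its body with griddepcontrol.wait / griddepcontrol.launch_dependents on sm_90+, and the host launches through cudaLaunchKernelEx with the programmatic-stream-serialization attribute gated on tensorrt_llm::common::getEnvEnablePDL(). For reference, a minimal standalone sketch of that pattern follows; the kernel, launch dimensions, and enablePDL flag here are illustrative assumptions, not code from this commit.

#include <cuda_runtime.h>
#include <cstdint>

// Illustrative kernel: scales n floats by s, bracketed by PDL grid-dependency controls.
__global__ void scaleKernelPDL(float* out, float const* in, float s, int64_t n)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Block until the upstream kernel signals that its results are visible.
    asm volatile("griddepcontrol.wait;");
#endif
    for (int64_t i = threadIdx.x + blockIdx.x * static_cast<int64_t>(blockDim.x); i < n;
         i += static_cast<int64_t>(blockDim.x) * gridDim.x)
    {
        out[i] = in[i] * s;
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Let the next kernel on the stream start launching before this grid fully retires.
    asm volatile("griddepcontrol.launch_dependents;");
#endif
}

// Host-side launch mirroring the cudaLaunchKernelEx setup used in this commit.
void launchScaleWithPDL(float* out, float const* in, float s, int64_t n, cudaStream_t stream, bool enablePDL)
{
    cudaLaunchConfig_t config = {};
    config.gridDim = dim3(1024);
    config.blockDim = dim3(256);
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = enablePDL ? 1 : 0;
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaLaunchKernelEx(&config, scaleKernelPDL, out, in, s, n);
}

When the attribute is left at 0, cudaLaunchKernelEx behaves like an ordinary stream launch, and on pre-sm_90 architectures the griddepcontrol instructions are compiled out, so the same code path remains correct everywhere.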

cpp/tensorrt_llm/thop/CMakeLists.txt (+1)

@@ -52,6 +52,7 @@ add_library(
   fp4BatchedQuantize.cpp
   fp8BlockScalingGemm.cpp
   fp8Quantize.cpp
+  fusedTopkSoftmax.cpp
   gatherTreeOp.cpp
   logitsBitmaskOp.cpp
   mambaConv1dOp.cpp
cpp/tensorrt_llm/thop/fusedTopkSoftmax.cpp (new file, +72)

@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/workspace.h"
+#include "tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h"
+#include "tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h"
+#include "tensorrt_llm/runtime/torchUtils.h"
+#include "tensorrt_llm/thop/thUtils.h"
+
+#include <ATen/cuda/EmptyTensor.h>
+
+#include <cuda_fp16.h>
+
+#include <cstdint>
+
+namespace torch_ext
+{
+
+std::tuple<torch::Tensor, torch::Tensor> fused_topk_softmax(torch::Tensor const& router_logits, int64_t const top_k,
+    int64_t const num_experts_total, int64_t const start_expert, int64_t const end_expert)
+{
+    // TODO: enable once the kernel has been added to the internal CUTLASS library.
+    TLLM_CHECK_WITH_INFO(false, "Fused topk/softmax op has not been enabled yet.");
+
+    CHECK_INPUT(router_logits, torch::kBFloat16);
+
+    auto const& router_logits_shape = router_logits.sizes();
+    auto const& rank = router_logits_shape.size();
+
+    TORCH_CHECK(rank == 2, "router_logits should be 2D tensor.");
+    int64_t const num_rows = router_logits_shape[0];
+
+    auto token_final_scales
+        = torch::empty({num_rows, top_k}, torch::dtype(torch::kFloat32).device(router_logits.device()));
+    auto token_selected_experts
+        = torch::empty({num_rows, top_k}, torch::dtype(torch::kInt32).device(router_logits.device()));
+
+    // auto stream = at::cuda::getCurrentCUDAStream(router_logits.get_device());
+    // tensorrt_llm::kernels::topkGatingSoftmaxKernelLauncher(
+    //     static_cast<__nv_bfloat16 const*>(router_logits.const_data_ptr()),
+    //     static_cast<float*>(token_final_scales.data_ptr()), static_cast<int*>(token_selected_experts.data_ptr()),
+    //     num_rows, top_k, num_experts_total, start_expert, end_expert, stream);
+    return {token_final_scales, token_selected_experts};
+}
+} // namespace torch_ext
+
+TORCH_LIBRARY_FRAGMENT(trtllm, m)
+{
+    m.def(
+        "fused_topk_softmax(Tensor router_logits, int top_k, "
+        "int num_experts_total, int start_expert, "
+        "int end_expert) -> (Tensor, Tensor) ");
+}
+
+TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
+{
+    m.impl("fused_topk_softmax", &torch_ext::fused_topk_softmax);
+}
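
The new op is registered in the trtllm Torch library but is currently stubbed out: the TLLM_CHECK_WITH_INFO(false, ...) at the top fires before the commented-out kernel launcher, pending the kernel landing in the internal CUTLASS library. Below is a hedged sketch of how a caller might use it once enabled; the wrapper function, the top_k value, and the single-rank expert range are assumptions for illustration, not part of this commit.

#include <torch/torch.h>
#include <cstdint>
#include <tuple>

namespace torch_ext
{
// Declaration matching the definition in fusedTopkSoftmax.cpp above.
std::tuple<torch::Tensor, torch::Tensor> fused_topk_softmax(torch::Tensor const& router_logits, int64_t top_k,
    int64_t num_experts_total, int64_t start_expert, int64_t end_expert);
} // namespace torch_ext

std::tuple<torch::Tensor, torch::Tensor> routeTokens(torch::Tensor const& router_logits)
{
    // router_logits: [num_tokens, num_experts] bfloat16 tensor on a CUDA device.
    int64_t const top_k = 2;                          // illustrative value
    int64_t const num_experts_total = router_logits.size(1);
    int64_t const start_expert = 0;                   // single-rank case: this rank owns all experts
    int64_t const end_expert = num_experts_total;
    // Returns (token_final_scales [num_tokens, top_k] fp32, token_selected_experts [num_tokens, top_k] int32).
    return torch_ext::fused_topk_softmax(router_logits, top_k, num_experts_total, start_expert, end_expert);
}

From Python, the same entry point would surface as torch.ops.trtllm.fused_topk_softmax once the extension library is loaded.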

examples/mllama/requirements.txt (+1)

@@ -1 +1,2 @@
 nvidia-modelopt[torch]~=0.21.0
+transformers==4.48.3

examples/pytorch/README.md (+1)

@@ -55,6 +55,7 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
 | `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B` | L |
 | `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B` | L |
 | `Qwen2VLForConditionalGeneration` | Qwen2-VL | `Qwen/Qwen2-VL-7B-Instruct` | L + V |
+| `Llama4ForConditionalGeneration` | Llama 4 | `meta-llama/Llama-4-Scout-17B-16E-Instruct` | L |

 Note:
 - L: Language only

requirements.txt (+2 −2)

@@ -23,10 +23,10 @@ tensorrt~=10.8.0
 # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-01.html#rel-25-01 uses 2.6.0a0.
 torch>=2.6.0a0,<=2.6.0
 torchvision
-nvidia-modelopt[torch]~=0.25.0
+nvidia-modelopt[torch]~=0.27.0
 nvidia-nccl-cu12
 nvidia-cuda-nvrtc-cu12
-transformers==4.48.3
+transformers==4.51.0
 pydantic>=2.9.1
 pillow==10.3.0
 wheel

tensorrt_llm/_torch/model_config.py (+2 −2)

@@ -89,8 +89,8 @@ def from_pretrained(cls,
         # Find the cache path by looking for the config.json file which should be in all
         # huggingface models
         model_dir = Path(
-            transformers.file_utils.get_file_from_repo(checkpoint_dir,
-                                                       'config.json')).parent
+            transformers.utils.hub.cached_file(checkpoint_dir,
+                                               'config.json')).parent
         quant_config = QuantConfig()
         layer_quant_config = None
         # quantized ckpt in modelopt format
