 
 #include "tensorrt_llm/common/cudaFp8Utils.h"
 #include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/common/reduceKernelUtils.cuh"
 #include <algorithm>
 #include <cstdio>
@@ -40,6 +41,10 @@ __inline__ __device__ float scale(float a, float b)
 template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
 __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
 {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.wait;");
+#endif
+
     for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
     {
 
@@ -56,6 +61,9 @@ __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
             output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[0])));
         }
     }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }
 
 template <typename T_OUT, typename T_S, typename T_IN>
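For reference, the two PTX instructions added to the scaleMatrix kernel above are the device-side half of Programmatic Dependent Launch (PDL) on compute capability 9.0 and later: griddepcontrol.wait blocks until the preceding grid on the stream has made its memory writes visible, and griddepcontrol.launch_dependents signals that dependent grids may start launching before this grid fully exits. The standalone sketch below (the kernel name and body are illustrative, not taken from this file) shows the same guard-plus-wait/signal pattern in isolation:

// Illustrative sketch of the PDL device-side pattern (hypothetical kernel, not part of the diff):
// wait for the producer grid's writes, do the work, then let dependent grids start early.
__global__ void pdlConsumerKernel(float* out, float const* in, int n)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;"); // block until the previous grid's memory is visible
#endif
    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += blockDim.x * gridDim.x)
    {
        out[i] = in[i] * 2.0f;
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;"); // allow dependent grids to launch
#endif
}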
@@ -64,18 +72,30 @@ void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* inp
 {
     dim3 grid(1024);
     dim3 block(CTA_SIZE);
+    cudaLaunchConfig_t config;
+    config.gridDim = grid;
+    config.blockDim = block;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
     if (quantize_mode == QuantizeMode::PER_CHANNEL)
     {
-        scaleMatrix<QuantizeMode::PER_CHANNEL, true>
-            <<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_CHANNEL, true, T_OUT, T_S, T_IN>, output, input_scale,
+            input, numel, lda);
     }
     else if (quantize_mode == QuantizeMode::PER_TOKEN)
     {
-        scaleMatrix<QuantizeMode::PER_TOKEN, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_TOKEN, true, T_OUT, T_S, T_IN>, output, input_scale,
+            input, numel, lda);
     }
     else if (quantize_mode == QuantizeMode::PER_TENSOR)
     {
-        scaleMatrix<QuantizeMode::PER_TENSOR, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_TENSOR, true, T_OUT, T_S, T_IN>, output, input_scale,
+            input, numel, lda);
     }
     sync_check_cuda_error(stream);
 }
@@ -86,19 +106,30 @@ void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
 {
     dim3 grid(1024);
     dim3 block(CTA_SIZE);
+    cudaLaunchConfig_t config;
+    config.gridDim = grid;
+    config.blockDim = block;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
     if (quantize_mode == QuantizeMode::PER_CHANNEL)
     {
-        scaleMatrix<QuantizeMode::PER_CHANNEL, false>
-            <<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_CHANNEL, false, T_OUT, T_S, T_IN>, output,
+            input_scale, input, numel, lda);
     }
     else if (quantize_mode == QuantizeMode::PER_TOKEN)
     {
-        scaleMatrix<QuantizeMode::PER_TOKEN, false><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_TOKEN, false, T_OUT, T_S, T_IN>, output, input_scale,
+            input, numel, lda);
     }
     else if (quantize_mode == QuantizeMode::PER_TENSOR)
     {
-        scaleMatrix<QuantizeMode::PER_TENSOR, false>
-            <<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
+        cudaLaunchKernelEx(&config, scaleMatrix<QuantizeMode::PER_TENSOR, false, T_OUT, T_S, T_IN>, output, input_scale,
+            input, numel, lda);
     }
     sync_check_cuda_error(stream);
 }