Commit f864f60

JohannesGaessler authored and YellowRoseCx committed
CUDA: add __restrict__ to mul mat vec kernels (ggml-org#2140)
1 parent 4539bc2 commit f864f60
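
For readers unfamiliar with the qualifier: __restrict__ on a kernel's pointer parameters promises the compiler that the pointers do not alias one another, so it may reorder, hoist, and cache loads more aggressively. A minimal sketch of the idea, not taken from this commit; the kernel name and launch configuration below are made up for illustration:

```cuda
// Because x, y and dst are declared __restrict__, the compiler may assume the
// store to dst[i] cannot overwrite x[] or y[], so it is free to reorder or
// keep the loads of x[i] and y[i] in registers.
__global__ void axpy_example(const float * __restrict__ x,
                             const float * __restrict__ y,
                             float * __restrict__ dst,
                             const float alpha, const int n) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    dst[i] = alpha*x[i] + y[i];
}

// Hypothetical launch: axpy_example<<<(n + 255)/256, 256>>>(x, y, dst, alpha, n);
```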

File tree: 1 file changed, +25 −28 lines

ggml-cuda.cu

+25 −28
@@ -113,8 +113,8 @@ typedef float2 dfloat2;
 #endif //GGML_CUDA_DMMV_F16
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
-typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
-typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
+typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
+typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_cuda_op_t)(
@@ -185,7 +185,7 @@ typedef struct {
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
 
-typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs);
 
 //================================= k-quants
 
@@ -461,7 +461,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 
 //================================== k-quants
 
-static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
 
     const int i = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;
@@ -494,7 +494,7 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
 }
 
-static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
 
     const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
@@ -558,7 +558,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
 }
 #endif
 
-static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
@@ -598,7 +598,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
     const int i = blockIdx.x;
@@ -644,7 +644,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
@@ -688,7 +688,7 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
@@ -796,7 +796,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
@@ -900,7 +900,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
@@ -1003,7 +1003,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
 
     const int row = blockIdx.x;
     const int num_blocks_per_row = ncols / QK_K;
@@ -1107,7 +1107,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
@@ -1225,7 +1225,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
-static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -1261,7 +1261,7 @@ static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_block(const void * vx, float * y, const int k) {
+static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
     const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
 
     if (i >= k) {
@@ -1281,7 +1281,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     y[iybs + iqs + y_offset] = v.y;
 }
 
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
@@ -1306,7 +1306,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
 
@@ -1331,7 +1331,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
 
@@ -1366,7 +1366,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
 
@@ -1400,7 +1400,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
 
@@ -1420,7 +1420,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, cons
 }
 
 template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
-static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
+static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
@@ -1458,7 +1458,7 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
     // qr = number of quantized weights per data value in x block
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
@@ -1525,7 +1525,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
     }
 }
 
-static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
     const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
@@ -1572,7 +1572,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
 }
 
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
     const int row_stride_x, const int channel_stride_x) {
 
     const half * x = (const half *) vx;
@@ -2434,10 +2434,7 @@ inline void ggml_cuda_op_mul_mat_vec(
         src0->type == GGML_TYPE_Q5_1 ||
         src0->type == GGML_TYPE_Q8_0;
 
-    // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6.
-    // However, they have bad performance with Pascal cards.
-    // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q.
-    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
+    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
 #endif
 
     if (use_mul_mat_vec_q) {
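
The final hunk also drops the Pascal-related comment and lowers the compute-capability threshold for the mul_mat_vec_q path from 700 to 600, so the integer-intrinsic kernels are now selected on compute capability 6.x devices as well. g_compute_capabilities is a per-device table filled at initialization; a hedged sketch of how such a table can be built with the CUDA runtime API (the names and array size below are illustrative, not copied from ggml-cuda.cu):

```cuda
#include <cuda_runtime.h>

// Illustrative only: build a per-device compute-capability table so the kernel
// path can be chosen at runtime, e.g. capability >= 600 picks mul_mat_vec_q.
static int compute_capabilities_example[16];

static void init_compute_capabilities(void) {
    int device_count = 0;
    cudaGetDeviceCount(&device_count);
    for (int id = 0; id < device_count && id < 16; ++id) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, id);
        // encode e.g. CC 6.1 as 610 so it compares directly against thresholds like 600
        compute_capabilities_example[id] = 100*prop.major + 10*prop.minor;
    }
}
```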
