@@ -113,8 +113,8 @@ typedef float2 dfloat2;
#endif // GGML_CUDA_DMMV_F16

typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
- typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
- typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
+ typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
+ typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
typedef void (*ggml_cuda_op_t)(
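Note (illustration, not part of the patch): `__restrict__` on these pointer parameters promises the compiler that the buffers never alias each other, which lets nvcc keep loaded values in registers and reorder memory accesses more aggressively. A minimal toy kernel (hypothetical, not from ggml-cuda.cu) carrying the same annotation:

    // Because y is guaranteed not to overlap x, the compiler does not have to
    // reload x[i] after the store to y[i] and can schedule the loads earlier.
    __global__ void scale(const float * __restrict__ x, float * __restrict__ y, const float a, const int n) {
        const int i = blockDim.x*blockIdx.x + threadIdx.x;
        if (i < n) {
            y[i] = a*x[i];
        }
    }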
@@ -185,7 +185,7 @@ typedef struct {
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");

- typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
+ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs);

// ================================= k-quants
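Note (illustration, not part of the patch): the static_assert above pins the memory layout that every vec_dot_q_cuda_t implementation relies on: two fp16 fields followed by QK8_0 packed 8-bit quants. A hedged sketch of the layout it checks; the struct itself is defined outside this diff, so the field names below are assumptions:

    #include <stdint.h>

    #define QK8_0 32
    typedef uint16_t ggml_fp16_t;   // half-precision value stored as raw bits

    typedef struct {
        ggml_fp16_t d;              // scale (field name assumed)
        ggml_fp16_t s;              // block sum used by the q8_1 dot products (field name assumed)
        int8_t      qs[QK8_0];      // 32 quantized weights
    } block_q8_1;
    // sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0 holds for this layout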
@@ -461,7 +461,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in

// ================================== k-quants

- static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
+ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {

const int i = blockIdx.x;
const block_q2_K * x = (const block_q2_K *) vx;
@@ -494,7 +494,7 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {

}

- static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
+ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {

const int i = blockIdx.x;
const block_q3_K * x = (const block_q3_K *) vx;
@@ -558,7 +558,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
}
#endif

- static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
+ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
const block_q4_K * x = (const block_q4_K *) vx;

const int i = blockIdx.x;
@@ -598,7 +598,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
#endif
}

- static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
+ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
const block_q5_K * x = (const block_q5_K *) vx;

const int i = blockIdx.x;
@@ -644,7 +644,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
#endif
}

- static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
+ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
const block_q6_K * x = (const block_q6_K *) vx;

const int i = blockIdx.x;
@@ -688,7 +688,7 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
#endif
}

- static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -796,7 +796,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
}
}

- static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
@@ -900,7 +900,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
}
}

- static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
@@ -1003,7 +1003,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
}
}

- static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
+ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {

const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
@@ -1107,7 +1107,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
}
}

- static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -1225,7 +1225,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
v.y = x[ib + iqs + 1];
}

- static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
+ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
@@ -1261,7 +1261,7 @@ static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
- static __global__ void dequantize_block(const void * vx, float * y, const int k) {
+ static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;

if (i >= k) {
@@ -1281,7 +1281,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
y[iybs + iqs + y_offset] = v.y;
}

- static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
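Note (illustration, not part of the patch): the vec_dot_*_q8_1 functions are compiled only for `__CUDA_ARCH__ >= 600` because their bodies rely on integer dot-product intrinsics. Those bodies are elided in this diff, so the following is only a hedged sketch of that kind of per-byte dot product, using `__dp4a` (documented for compute capability 6.1 and newer) with a manual fallback:

    // Hypothetical helper, for illustration only: dot product of eight signed
    // 8-bit values packed four per int, accumulated on top of sumi.
    static __device__ __forceinline__ int dot8(const int a, const int b, int sumi) {
    #if __CUDA_ARCH__ >= 610 // __dp4a is available from sm_61 onwards
        return __dp4a(a, b, sumi); // per-byte multiply-add in a single instruction
    #else
        for (int k = 0; k < 4; ++k) {
            const int va = (int)(signed char)((a >> (8*k)) & 0xff);
            const int vb = (int)(signed char)((b >> (8*k)) & 0xff);
            sumi += va*vb;
        }
        return sumi;
    #endif
    }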
@@ -1306,7 +1306,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, cons
#endif // __CUDA_ARCH__ >= 600
}

- static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
@@ -1331,7 +1331,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, cons
#endif // __CUDA_ARCH__ >= 600
}

- static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
@@ -1366,7 +1366,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, cons
#endif // __CUDA_ARCH__ >= 600
}

- static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
@@ -1400,7 +1400,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, cons
#endif // __CUDA_ARCH__ >= 600
}

- static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
@@ -1420,7 +1420,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, cons
}

template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
- static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
+ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.y*blockDim.y + threadIdx.y;

if (row >= nrows) {
@@ -1458,7 +1458,7 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
- static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
const int row = blockIdx.y*blockDim.y + threadIdx.y;
@@ -1525,7 +1525,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
}
}

- static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
const half * x = (const half *) vx;

const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
@@ -1572,7 +1572,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
}

static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
- const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+ const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int channel_stride_x) {

const half * x = (const half *) vx;
@@ -2434,10 +2434,7 @@ inline void ggml_cuda_op_mul_mat_vec(
src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0;

- // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6.
- // However, they have bad performance with Pascal cards.
- // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q.
- const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
+ const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
#endif

if (use_mul_mat_vec_q) {
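Note (illustration, not part of the patch): the gate for mul_mat_vec_q drops from compute capability 700 to 600 here, and the per-GPU decision still reads g_compute_capabilities[id], which is filled in elsewhere during initialization. A minimal sketch of how such a table can be populated at runtime (hypothetical standalone helper, not the file's actual init path; the major*100 + minor*10 encoding is an assumption that matches the >= 600 comparison above):

    #include <cuda_runtime.h>
    #include <vector>

    // Query every visible device and encode its compute capability so that,
    // for example, 6.1 becomes 610 and 7.0 becomes 700.
    static std::vector<int> query_compute_capabilities() {
        int device_count = 0;
        cudaGetDeviceCount(&device_count);
        std::vector<int> ccs(device_count);
        for (int id = 0; id < device_count; ++id) {
            cudaDeviceProp prop;
            cudaGetDeviceProperties(&prop, id);
            ccs[id] = 100*prop.major + 10*prop.minor;
        }
        return ccs;
    }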