@@ -4297,7 +4297,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
4297
4297
4298
4298
static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
4299
4299
if (j < 4) {
4300
- d = q[j] & 63; m = q[j + 4] & 63;
4300
+ d = q[j] & 63;
4301
+ m = q[j + 4] & 63;
4301
4302
} else {
4302
4303
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
4303
4304
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
@@ -4306,7 +4307,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
4306
4307
4307
4308
template<typename dst_t>
4308
4309
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
4309
- const sycl::nd_item<3> &item_ct1) {
4310
+ uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
4310
4311
const block_q4_K * x = (const block_q4_K *) vx;
4311
4312
4312
4313
const int i = item_ct1.get_group(2);
@@ -4320,19 +4321,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
4320
4321
4321
4322
dst_t * y = yy + i*QK_K + 64*il + n*ir;
4322
4323
4323
- const float dall = x[i].dm[0];
4324
- const float dmin = x[i].dm[1];
4324
+ const sycl::half2 dm = x[i].dm;
4325
+ const float dall = dm[0];
4326
+ const float dmin = dm[1];
4325
4327
4326
- const uint8_t * q = x[i].qs + 32*il + n*ir;
4328
+ if (tid < 12)
4329
+ scales_local[tid] = x[i].scales[tid];
4330
+ item_ct1.barrier(sycl::access::fence_space::local_space);
4327
4331
4328
4332
uint8_t sc, m;
4329
- get_scale_min_k4(is + 0, x[i].scales, sc, m);
4330
- const float d1 = dall * sc; const float m1 = dmin * m;
4331
- get_scale_min_k4(is + 1, x[i].scales, sc, m);
4332
- const float d2 = dall * sc; const float m2 = dmin * m;
4333
+ get_scale_min_k4(is + 0, scales_local, sc, m);
4334
+ const float d1 = dall * sc;
4335
+ const float m1 = dmin * m;
4336
+ get_scale_min_k4(is + 1, scales_local, sc, m);
4337
+ const float d2 = dall * sc;
4338
+ const float m2 = dmin * m;
4339
+
4340
+ sycl::vec<uint8_t, n> q_vec = reinterpret_cast<const sycl::vec<uint8_t, n>*>(x[i].qs + 32*il + n*ir)[0];
4333
4341
for (int l = 0; l < n; ++l) {
4334
- y[l + 0] = d1 * (q [l] & 0xF) - m1;
4335
- y[l +32] = d2 * (q [l] >> 4) - m2;
4342
+ y[l + 0] = d1 * (q_vec [l] & 0xF) - m1;
4343
+ y[l +32] = d2 * (q_vec [l] >> 4) - m2;
4336
4344
}
4337
4345
}
4338
4346
@@ -9888,12 +9896,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
9888
9896
dpct::has_capability_or_fail(stream->get_device(),
9889
9897
{sycl::aspect::fp16});
9890
9898
9891
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
9899
+ stream->submit([&](sycl::handler &cgh) {
9900
+ sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
9901
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
9892
9902
sycl::range<3>(1, 1, 32),
9893
9903
sycl::range<3>(1, 1, 32)),
9894
9904
[=](sycl::nd_item<3> item_ct1) {
9895
- dequantize_block_q4_K(vx, y, item_ct1);
9905
+ dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
9896
9906
});
9907
+ });
9897
9908
}
9898
9909
}
9899
9910
0 commit comments