Skip to content

Commit 0a321fc

Browse files
committed
Fix "too many shader groups called" validation error in llama3 on AMD and Intel GPUs
1 parent d63aca3 commit 0a321fc

8 files changed

+19802
-19336
lines changed

ggml-vulkan-shaders.hpp

+19,771-19,322
Large diffs are not rendered by default.

ggml-vulkan.cpp

+25-8
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,6 @@ struct vk_context {
384384
};
385385

386386
struct ggml_tensor_extra_gpu {
387-
ggml_backend_vk_context * backend_ctx;
388387
size_t ctx_idx;
389388

390389
vk_buffer_ref buffer_gpu;
@@ -2746,9 +2745,6 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
27462745
ggml_vk_ensure_sync_staging_buffer(src->device, size);
27472746
ggml_vk_ensure_sync_staging_buffer(dst->device, size);
27482747

2749-
std::lock_guard<std::mutex> src_lock(src->device->mutex);
2750-
std::lock_guard<std::mutex> dst_lock(dst->device->mutex);
2751-
27522748
// Copy to src staging buffer
27532749
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
27542750
// memcpy to dst staging buffer
@@ -3228,18 +3224,30 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
32283224
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
32293225
}
32303226

3227+
const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
3228+
3229+
uint32_t groups_x = ne01;
3230+
uint32_t groups_z = 1;
3231+
3232+
if (ne01 > max_groups_x) {
3233+
groups_z = 64;
3234+
groups_x /= groups_z;
3235+
}
3236+
32313237
// compute
32323238
const vk_mat_vec_push_constants pc = {
32333239
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
32343240
stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
32353241
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
32363242
};
32373243
ggml_vk_sync_buffers(subctx);
3238-
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} }, sizeof(vk_mat_vec_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne12 * ne13), 1});
3244+
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
3245+
{ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} },
3246+
sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
32393247
}
32403248

32413249
static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3242-
VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
3250+
VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
32433251
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
32443252
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
32453253
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
@@ -3740,6 +3748,16 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
37403748
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
37413749
}
37423750

3751+
const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
3752+
3753+
uint32_t groups_x = ne01;
3754+
uint32_t groups_z = 1;
3755+
3756+
if (ne01 > max_groups_x) {
3757+
groups_z = 64;
3758+
groups_x /= groups_z;
3759+
}
3760+
37433761
// compute
37443762
const vk_mat_vec_id_push_constants pc = {
37453763
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
@@ -3749,7 +3767,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
37493767
ggml_vk_sync_buffers(subctx);
37503768
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
37513769
{ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23}, { d_ids, ids_buf_offset, ids_sz } },
3752-
sizeof(vk_mat_vec_id_push_constants), &pc, { (uint32_t)ne01, (uint32_t)nei0, 1 });
3770+
sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
37533771
}
37543772

37553773
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
@@ -5606,7 +5624,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
56065624
}
56075625

56085626
extra->ctx_idx = ctx->compute_ctx->idx;
5609-
extra->backend_ctx = ctx;
56105627

56115628
#ifdef GGML_VULKAN_CHECK_RESULTS
56125629
// Force context reset on each node so that each tensor ends up in its own context

vulkan-shaders/mul_mat_vec.comp

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
1313
shared FLOAT_TYPE tmp[BLOCK_SIZE];
1414

1515
void main() {
16-
const uint row = gl_WorkGroupID.x;
16+
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
1717
const uint tid = gl_LocalInvocationID.x;
1818

1919
uint a_offset, b_offset, d_offset;

vulkan-shaders/mul_mat_vec_q2_k.comp

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
77
shared FLOAT_TYPE tmp[32];
88

99
void main() {
10-
const uint row = gl_WorkGroupID.x;
10+
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
1111

1212
uint a_offset, b_offset, d_offset;
1313
get_offsets(a_offset, b_offset, d_offset);

vulkan-shaders/mul_mat_vec_q3_k.comp

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
77
shared FLOAT_TYPE tmp[32];
88

99
void main() {
10-
const uint row = gl_WorkGroupID.x;
10+
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
1111

1212
uint a_offset, b_offset, d_offset;
1313
get_offsets(a_offset, b_offset, d_offset);

vulkan-shaders/mul_mat_vec_q4_k.comp

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
77
shared FLOAT_TYPE tmp[32];
88

99
void main() {
10-
const uint row = gl_WorkGroupID.x;
10+
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
1111

1212
uint a_offset, b_offset, d_offset;
1313
get_offsets(a_offset, b_offset, d_offset);

vulkan-shaders/mul_mat_vec_q5_k.comp

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
77
shared FLOAT_TYPE tmp[32];
88

99
void main() {
10-
const uint row = gl_WorkGroupID.x;
10+
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
1111

1212
uint a_offset, b_offset, d_offset;
1313
get_offsets(a_offset, b_offset, d_offset);

vulkan-shaders/mul_mat_vec_q6_k.comp

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
77
shared FLOAT_TYPE tmp[32];
88

99
void main() {
10-
const uint row = gl_WorkGroupID.x;
10+
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
1111

1212
uint a_offset, b_offset, d_offset;
1313
get_offsets(a_offset, b_offset, d_offset);

0 commit comments

Comments
 (0)