Skip to content

Commit ba94c9d

Browse files
committed
ggml : parallelize ggml_exp
This results in 8% faster token generation for Mamba-130M.
1 parent e67c420 commit ba94c9d

File tree

1 file changed

+19
-14
lines changed

ggml.c

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8414,27 +8414,32 @@ static void ggml_compute_forward_exp_f32(
84148414
const struct ggml_compute_params * params,
84158415
const struct ggml_tensor * src0,
84168416
struct ggml_tensor * dst) {
8417-
GGML_ASSERT(params->ith == 0);
8417+
GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
8418+
GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
84188419
GGML_ASSERT(ggml_are_same_shape(src0, dst));
84198420

84208421
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
84218422
return;
84228423
}
84238424

8424-
GGML_ASSERT( dst->nb[0] == sizeof(float));
8425-
GGML_ASSERT(src0->nb[0] == sizeof(float));
8425+
const int ith = params->ith;
8426+
const int nth = params->nth;
84268427

8427-
GGML_TENSOR_UNARY_OP_LOCALS
8428+
const int nc = src0->ne[0];
8429+
const int nr = ggml_nrows(src0);
84288430

8429-
for (int64_t i3 = 0; i3 < ne03; i3++) {
8430-
for (int64_t i2 = 0; i2 < ne02; i2++) {
8431-
for (int64_t i1 = 0; i1 < ne01; i1++) {
8432-
float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
8433-
float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
8434-
ggml_vec_exp_f32(ne00, dst_row, src_row);
8435-
}
8436-
}
8437-
}
8431+
// rows per thread
8432+
const int dr = (nr + nth - 1)/nth;
8433+
8434+
// row range for this thread
8435+
const int ir0 = dr*ith;
8436+
const int ir1 = MIN(ir0 + dr, nr);
8437+
8438+
for (int i1 = ir0; i1 < ir1; i1++) {
8439+
ggml_vec_exp_f32(nc,
8440+
(float *) ((char *) dst->data + i1*( dst->nb[1])),
8441+
(float *) ((char *) src0->data + i1*(src0->nb[1])));
8442+
};
84388443
}
84398444

84408445
static void ggml_compute_forward_exp(
@@ -16850,13 +16855,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1685016855
case GGML_OP_ADD:
1685116856
case GGML_OP_ADD1:
1685216857
case GGML_OP_ACC:
16858+
case GGML_OP_EXP:
1685316859
{
1685416860
n_tasks = n_threads;
1685516861
} break;
1685616862
case GGML_OP_SUB:
1685716863
case GGML_OP_SQR:
1685816864
case GGML_OP_SQRT:
16859-
case GGML_OP_EXP:
1686016865
case GGML_OP_LOG:
1686116866
case GGML_OP_SUM:
1686216867
case GGML_OP_SUM_ROWS:

0 commit comments

Comments (0)