@@ -8414,27 +8414,32 @@ static void ggml_compute_forward_exp_f32(
8414
8414
const struct ggml_compute_params * params,
8415
8415
const struct ggml_tensor * src0,
8416
8416
struct ggml_tensor * dst) {
8417
- GGML_ASSERT(params->ith == 0);
8417
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
8418
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
8418
8419
GGML_ASSERT(ggml_are_same_shape(src0, dst));
8419
8420
8420
8421
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
8421
8422
return;
8422
8423
}
8423
8424
8424
- GGML_ASSERT( dst->nb[0] == sizeof(float)) ;
8425
- GGML_ASSERT(src0->nb[0] == sizeof(float)) ;
8425
+ const int ith = params->ith ;
8426
+ const int nth = params->nth ;
8426
8427
8427
- GGML_TENSOR_UNARY_OP_LOCALS
8428
+ const int nc = src0->ne[0];
8429
+ const int nr = ggml_nrows(src0);
8428
8430
8429
- for (int64_t i3 = 0; i3 < ne03; i3++) {
8430
- for (int64_t i2 = 0; i2 < ne02; i2++) {
8431
- for (int64_t i1 = 0; i1 < ne01; i1++) {
8432
- float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
8433
- float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
8434
- ggml_vec_exp_f32(ne00, dst_row, src_row);
8435
- }
8436
- }
8437
- }
8431
+ // rows per thread
8432
+ const int dr = (nr + nth - 1)/nth;
8433
+
8434
+ // row range for this thread
8435
+ const int ir0 = dr*ith;
8436
+ const int ir1 = MIN(ir0 + dr, nr);
8437
+
8438
+ for (int i1 = ir0; i1 < ir1; i1++) {
8439
+ ggml_vec_exp_f32(nc,
8440
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
8441
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
8442
+ };
8438
8443
}
8439
8444
8440
8445
static void ggml_compute_forward_exp(
@@ -16850,13 +16855,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
16850
16855
case GGML_OP_ADD:
16851
16856
case GGML_OP_ADD1:
16852
16857
case GGML_OP_ACC:
16858
+ case GGML_OP_EXP:
16853
16859
{
16854
16860
n_tasks = n_threads;
16855
16861
} break;
16856
16862
case GGML_OP_SUB:
16857
16863
case GGML_OP_SQR:
16858
16864
case GGML_OP_SQRT:
16859
- case GGML_OP_EXP:
16860
16865
case GGML_OP_LOG:
16861
16866
case GGML_OP_SUM:
16862
16867
case GGML_OP_SUM_ROWS:
0 commit comments