@@ -616,6 +616,8 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q
 #define CUDA_UPSCALE_BLOCK_SIZE 256
 #define CUDA_CONCAT_BLOCK_SIZE 256
 #define CUDA_PAD_BLOCK_SIZE 256
+#define CUDA_ARANGE_BLOCK_SIZE 256
+#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
 #define CUDA_ACC_BLOCK_SIZE 256
 #define CUDA_IM2COL_BLOCK_SIZE 256
 #define CUDA_POOL2D_BLOCK_SIZE 256
@@ -990,17 +992,21 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
             nidx +
             blockIdx.y * ne0 +
             blockIdx.z * ne0 * gridDim.y;
-    dst[offset_dst] = x[offset_src];
+        dst[offset_dst] = x[offset_src];
     } else {
         int offset_src =
             nidx +
             blockIdx.y * ne0 +
             (blockIdx.z - ne02) * ne0 * gridDim.y;
-    dst[offset_dst] = y[offset_src];
+        dst[offset_dst] = y[offset_src];
     }
 }
 
-static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
+static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) {
+    // blockIdx.z: idx of ne02*ne03
+    // blockIdx.y: idx of ne01*scale_factor, aka ne1
+    // blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE
+    // ne00xne01: ne00 * ne01
     int ne0 = ne00 * scale_factor;
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
@@ -1012,15 +1018,18 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
     int offset_src =
         i00 +
         i01 * ne00 +
-        blockIdx.z * nb02;
+        blockIdx.z * ne00xne01;
     int offset_dst =
         nidx +
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
     dst[offset_dst] = x[offset_src];
 }
 
-static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
+static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -1031,19 +1040,53 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
         nidx +
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
-    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
         int offset_src =
             nidx +
             blockIdx.y * ne00 +
             blockIdx.z * ne00 * ne01;
-    dst[offset_dst] = x[offset_src];
+        dst[offset_dst] = x[offset_src];
     } else {
         dst[offset_dst] = 0.0f;
     }
 }
 
+static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    dst[nidx] = start + step * nidx;
+}
+
+static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
+    // blockIDx.y: idx of timesteps->ne[0]
+    // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
+    int i = blockIdx.y;
+    int j = threadIdx.x + blockIdx.x * blockDim.x;
+    float * embed_data = (float *)((char *)dst + i*nb1);
+
+    if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+        embed_data[dim] = 0.f;
+    }
+
+    int half = dim / 2;
+    if (j >= half) {
+        return;
+    }
+
+    float timestep = timesteps[i];
+    float freq = (float)expf(-logf(max_period) * j / half);
+    float arg = timestep * freq;
+    embed_data[j] = cosf(arg);
+    embed_data[j + half] = sinf(arg);
+}
+
 
 template <int block_size>
 static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+    // blockIdx.x: num_groups idx
+    // threadIdx.x: block_size idx
     int start = blockIdx.x * group_size;
     int end = start + group_size;
 
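The timestep_embedding_f32 kernel added above writes the usual diffusion-style sinusoidal embedding: frequency expf(-logf(max_period) * j / half) for each of the dim/2 pairs, cosines in the first half of the row and sines in the second. A minimal CPU sketch of the same math, illustrative only and not part of the diff (it assumes an even dim, so the odd-dim padding slot is omitted):

#include <cmath>

// Reference values for one timestep t, written into a row of `dim` floats.
static void timestep_embedding_ref(float t, float * out, int dim, int max_period) {
    const int half = dim / 2;
    for (int j = 0; j < half; ++j) {
        const float freq = expf(-logf((float) max_period) * j / half);
        out[j]        = cosf(t * freq); // first half of the row: cosines
        out[j + half] = sinf(t * freq); // second half of the row: sines
    }
}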
@@ -6448,25 +6491,25 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                    const int nb12, const int nb13) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
     }
 
     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
 
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
@@ -6956,23 +6999,23 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
 
 template <typename T>
 static __global__ void im2col_kernel(
-        const float * x, T * dst, int batch_offset,
-        int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
+        const float * x, T * dst, int64_t batch_offset,
+        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
-    const int i = threadIdx.x + blockIdx.x * blockDim.x;
+    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
     if (i >= pelements) {
         return;
     }
 
-    const int ksize = OW * (KH > 1 ? KW : 1);
-    const int kx = i / ksize;
-    const int kd = kx * ksize;
-    const int ky = (i - kd) / OW;
-    const int ix = i % OW;
+    const int64_t ksize = OW * (KH > 1 ? KW : 1);
+    const int64_t kx = i / ksize;
+    const int64_t kd = kx * ksize;
+    const int64_t ky = (i - kd) / OW;
+    const int64_t ix = i % OW;
 
-    const int oh = blockIdx.y;
-    const int batch = blockIdx.z / IC;
-    const int ic = blockIdx.z % IC;
+    const int64_t oh = blockIdx.y;
+    const int64_t batch = blockIdx.z / IC;
+    const int64_t ic = blockIdx.z % IC;
 
     const int64_t iiw = ix * s0 + kx * d0 - p0;
     const int64_t iih = oh * s1 + ky * d1 - p1;
@@ -7298,19 +7341,33 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, const
     concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
 }
 
-static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
+static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03,
+                             const int scale_factor, cudaStream_t stream) {
     int ne0 = (ne00 * scale_factor);
     int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
+    dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03);
     upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
 }
 
 static void pad_f32_cuda(const float * x, float * dst,
-    const int ne00, const int ne01, const int ne02,
-    const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, ne1, ne2);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
+    arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
+}
+
+static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
+                                        const int dim, const int max_period, cudaStream_t stream) {
+    int half_ceil = (dim + 1) / 2;
+    int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne00, 1);
+    timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
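The timestep embedding launcher above assigns one grid row per timestep and covers the (dim + 1) / 2 frequency slots along x. A quick sanity check of that launch geometry, using made-up sizes rather than anything from the diff:

// Hypothetical sizes, only to illustrate how the grid is shaped.
constexpr int dim        = 320;  // embedding width (assumed)
constexpr int n_steps    = 2;    // timesteps->ne[0] (assumed)
constexpr int block_size = 256;  // CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE
constexpr int half_ceil  = (dim + 1) / 2;                             // 160
constexpr int num_blocks = (half_ceil + block_size - 1) / block_size; // 1
// gridDim == (num_blocks, n_steps, 1): one 256-thread block per timestep row here.
static_assert(num_blocks == 1, "160 frequency slots fit in a single 256-thread block");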
@@ -8443,8 +8500,8 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
 
 template <typename T>
 static void im2col_cuda(const float * x, T* dst,
-    int IW, int IH, int OW, int OH, int KW, int KH, int IC,
-    int batch, int batch_offset, int offset_delta,
+    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+    int64_t batch, int64_t batch_offset, int64_t offset_delta,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
     const int parallel_elements = OW * KW * KH;
     const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
@@ -9123,7 +9180,7 @@ static void ggml_cuda_op_group_norm(
 
     int num_groups = dst->op_params[0];
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
+    group_norm_f32_cuda(src0_dd, dst_dd, num_groups * src0->ne[3], group_size, ggml_nelements(src0), main_stream);
 
     (void) src1;
     (void) dst;
@@ -9156,7 +9213,7 @@ static void ggml_cuda_op_upscale(
 
     const int scale_factor = dst->op_params[0];
 
-    upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
+    upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, main_stream);
 
     (void) src1;
     (void) dst;
@@ -9172,8 +9229,49 @@ static void ggml_cuda_op_pad(
     GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
     pad_f32_cuda(src0_dd, dst_dd,
-        src0->ne[0], src0->ne[1], src0->ne[2],
-        dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_arange(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    float start;
+    float stop;
+    float step;
+    memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
+    memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
+    memcpy(&step, (float *)dst->op_params + 2, sizeof(float));
+
+    int64_t steps = (int64_t)ceil((stop - start) / step);
+    GGML_ASSERT(ggml_nelements(dst) == steps);
+
+    arange_f32_cuda(dst_dd, dst->ne[0], start, step, main_stream);
+
+    (void) src0;
+    (void) src1;
+    (void) src0_dd;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_timestep_embedding(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int dim = dst->op_params[0];
+    const int max_period = dst->op_params[1];
+
+    timestep_embedding_f32_cuda(src0_dd, dst_dd, src0->ne[0], dst->nb[1], dim, max_period, main_stream);
 
     (void) src1;
     (void) dst;
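ggml_cuda_op_arange recomputes the element count from the packed start/stop/step and asserts that it matches ggml_nelements(dst) before launching. A small host-side check with made-up values, illustrative only:

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
    // Made-up arange parameters, mirroring the start/stop/step unpacked above.
    const float start = 0.0f, stop = 10.0f, step = 1.5f;
    const long long steps = (long long) ceilf((stop - start) / step); // ceil(6.67) == 7
    assert(steps == 7);
    for (long long i = 0; i < steps; ++i) {
        // same formula the arange_f32 kernel applies per element
        printf("%g ", start + step * i); // 0 1.5 3 4.5 6 7.5 9
    }
    printf("\n");
    return 0;
}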
@@ -10458,6 +10556,45 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
 }
 
+static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+    const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
+
+    // dd = data device
+    float * src0_ddf = nullptr;
+    float * src1_ddf = nullptr;
+    float * dst_ddf = nullptr;
+
+    cuda_pool_alloc<float> dst_f;
+
+    ggml_cuda_set_device(g_main_device);
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    if (dst_on_device) {
+        dst_ddf = (float *) dst_extra->data_device[g_main_device];
+    } else {
+        dst_ddf = dst_f.alloc(ggml_nelements(dst));
+    }
+
+    // do the computation
+    ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+    CUDA_CHECK(cudaGetLastError());
+
+    // copy dst to host if necessary
+    if (!dst_on_device) {
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
+    }
+
+    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+}
+
+static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_timestep_embedding);
+}
+
 static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }
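ggml_cuda_arange gets its own dispatch path rather than ggml_cuda_op_flatten because the op has no source tensors to flatten; it only needs a device buffer for dst. On the graph side these ops would typically be reached through the matching ggml builders; a hedged usage sketch (the ggml_arange and ggml_timestep_embedding signatures are assumed from the companion ggml API, which is not shown in this diff):

#include "ggml.h"

// Builds a [dim, n_steps] f32 table: one sinusoidal embedding row per timestep.
static struct ggml_tensor * build_timestep_embedding(struct ggml_context * ctx) {
    struct ggml_tensor * steps = ggml_arange(ctx, 0.0f, 1000.0f, 1.0f); // 1000 timesteps (assumed values)
    return ggml_timestep_embedding(ctx, steps, 320, 10000);             // dim = 320, max_period = 10000
}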
@@ -11358,6 +11495,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         case GGML_OP_PAD:
             func = ggml_cuda_pad;
             break;
+        case GGML_OP_ARANGE:
+            func = ggml_cuda_arange;
+            break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            func = ggml_cuda_timestep_embedding;
+            break;
         case GGML_OP_LEAKY_RELU:
             func = ggml_cuda_leaky_relu;
             break;
@@ -12253,6 +12396,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_ARANGE:
+        case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
             return true;
         default: