
Commit e3043a4

leejet authored and ggerganov committed (co-authored by slaren)
add some new ops, fix some operators, and add batch operations to certain operators (ggml/747)

* cuda: fix group_norm
* cuda: add batch inference support for ggml_pad/ggml_upscale
* add ggml_arange
* add ggml_timestep_embedding
* update ggml_arange/ggml_timestep_embedding tests
* cuda: fix im2col
* add ggml_arange/ggml_timestep_embedding support for metal backend
* fix some bugs
* fix some bugs
* Update ggml.h (Co-authored-by: Georgi Gerganov <[email protected]>)
* Update ggml-cuda.cu (Co-authored-by: Georgi Gerganov <[email protected]>)
* Update ggml-metal.m (Co-authored-by: Georgi Gerganov <[email protected]>)
* Update ggml-metal.m (Co-authored-by: Georgi Gerganov <[email protected]>)
* Update ggml-metal.metal (Co-authored-by: Georgi Gerganov <[email protected]>)
* modify according to the review comments
* ggml : fix compile warnings + code style
* ggml : normalize compute_forward calls + fix seg fault in debug
* minor

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: slaren <[email protected]>
1 parent 50973c7 commit e3043a4
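
For orientation before the diff: a minimal CPU sketch of what the two new ops compute, inferred from the CUDA kernels below. The ref_* names and the packed row stride are illustrative assumptions, not part of ggml's API.

#include <math.h>

// arange: dst[k] = start + step * k, for k in [0, ne0)
static void ref_arange(float * dst, int ne0, float start, float step) {
    for (int k = 0; k < ne0; ++k) {
        dst[k] = start + step * k;
    }
}

// timestep embedding: one row per timestep, cos/sin pairs of
// geometrically spaced frequencies; odd dims get one zero of padding
static void ref_timestep_embedding(const float * timesteps, float * dst,
                                   int n, int dim, int max_period) {
    const int half = dim / 2;
    const int row_stride = dim + dim % 2; // assumed: odd dims allocate one extra element
    for (int i = 0; i < n; ++i) {
        float * row = dst + i * row_stride;
        for (int j = 0; j < half; ++j) {
            const float freq = expf(-logf((float) max_period) * j / half);
            row[j]        = cosf(timesteps[i] * freq);
            row[j + half] = sinf(timesteps[i] * freq);
        }
        if (dim % 2 != 0) {
            row[dim] = 0.0f; // mirrors the kernel's zero pad at index dim
        }
    }
}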

File tree: 6 files changed (+551, -53 lines)

ggml-cuda.cu

Lines changed: 187 additions & 42 deletions
@@ -616,6 +616,8 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q
 #define CUDA_UPSCALE_BLOCK_SIZE 256
 #define CUDA_CONCAT_BLOCK_SIZE 256
 #define CUDA_PAD_BLOCK_SIZE 256
+#define CUDA_ARANGE_BLOCK_SIZE 256
+#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
 #define CUDA_ACC_BLOCK_SIZE 256
 #define CUDA_IM2COL_BLOCK_SIZE 256
 #define CUDA_POOL2D_BLOCK_SIZE 256
@@ -990,17 +992,21 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
             nidx +
             blockIdx.y * ne0 +
             blockIdx.z * ne0 * gridDim.y;
-            dst[offset_dst] = x[offset_src];
+        dst[offset_dst] = x[offset_src];
     } else {
         int offset_src =
             nidx +
             blockIdx.y * ne0 +
             (blockIdx.z - ne02) * ne0 * gridDim.y;
-            dst[offset_dst] = y[offset_src];
+        dst[offset_dst] = y[offset_src];
     }
 }

-static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
+static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) {
+    // blockIdx.z: idx of ne02*ne03
+    // blockIdx.y: idx of ne01*scale_factor, aka ne1
+    // blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE
+    // ne00xne01: ne00 * ne01
     int ne0 = ne00 * scale_factor;
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
@@ -1012,15 +1018,18 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
     int offset_src =
         i00 +
         i01 * ne00 +
-        blockIdx.z * nb02;
+        blockIdx.z * ne00xne01;
     int offset_dst =
         nidx +
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
     dst[offset_dst] = x[offset_src];
 }

-static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
+static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -1031,19 +1040,53 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
         nidx +
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
-    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
         int offset_src =
             nidx +
             blockIdx.y * ne00 +
             blockIdx.z * ne00 * ne01;
-            dst[offset_dst] = x[offset_src];
+        dst[offset_dst] = x[offset_src];
     } else {
         dst[offset_dst] = 0.0f;
     }
 }

+static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    dst[nidx] = start + step * nidx;
+}
+
+static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
+    // blockIDx.y: idx of timesteps->ne[0]
+    // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
+    int i = blockIdx.y;
+    int j = threadIdx.x + blockIdx.x * blockDim.x;
+    float * embed_data = (float *)((char *)dst + i*nb1);
+
+    if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+        embed_data[dim] = 0.f;
+    }
+
+    int half = dim / 2;
+    if (j >= half) {
+        return;
+    }
+
+    float timestep = timesteps[i];
+    float freq = (float)expf(-logf(max_period) * j / half);
+    float arg = timestep * freq;
+    embed_data[j] = cosf(arg);
+    embed_data[j + half] = sinf(arg);
+}
+
 template <int block_size>
 static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+    // blockIdx.x: num_groups idx
+    // threadIdx.x: block_size idx
     int start = blockIdx.x * group_size;
     int end = start + group_size;

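In closed form, the timestep_embedding_f32 kernel above computes the standard sinusoidal embedding used by diffusion models. With $h = \lfloor \mathrm{dim}/2 \rfloor$, for each timestep $t$ and every index $j < h$:

$$
f_j = \exp\!\left(-\frac{j \ln(\mathrm{max\_period})}{h}\right) = \mathrm{max\_period}^{-j/h},
\qquad e[j] = \cos(t \, f_j), \qquad e[j+h] = \sin(t \, f_j),
$$

and when dim is odd a single designated thread writes one zero at $e[\mathrm{dim}]$.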
@@ -6448,25 +6491,25 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                    const int nb12, const int nb13) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

     if (i >= ne) {
         return;
     }

     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;

     cpy_1(cx + x_offset, cdst + dst_offset);
 }
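The int to int64_t widening in cpy_f32_f16 guards the flattened index and byte offsets against 32-bit overflow. For an assumed, illustrative shape:

$$
4096 \times 4096 \times 256 = 2^{32} > 2^{31} - 1 = \mathtt{INT\_MAX},
$$

so the element count alone no longer fits in an int, and byte offsets (index times stride) overflow earlier still.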
@@ -6956,23 +6999,23 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,

 template <typename T>
 static __global__ void im2col_kernel(
-        const float * x, T * dst, int batch_offset,
-        int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
+        const float * x, T * dst, int64_t batch_offset,
+        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
-    const int i = threadIdx.x + blockIdx.x * blockDim.x;
+    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
     if (i >= pelements) {
         return;
     }

-    const int ksize = OW * (KH > 1 ? KW : 1);
-    const int kx = i / ksize;
-    const int kd = kx * ksize;
-    const int ky = (i - kd) / OW;
-    const int ix = i % OW;
+    const int64_t ksize = OW * (KH > 1 ? KW : 1);
+    const int64_t kx = i / ksize;
+    const int64_t kd = kx * ksize;
+    const int64_t ky = (i - kd) / OW;
+    const int64_t ix = i % OW;

-    const int oh = blockIdx.y;
-    const int batch = blockIdx.z / IC;
-    const int ic = blockIdx.z % IC;
+    const int64_t oh = blockIdx.y;
+    const int64_t batch = blockIdx.z / IC;
+    const int64_t ic = blockIdx.z % IC;

     const int64_t iiw = ix * s0 + kx * d0 - p0;
     const int64_t iih = oh * s1 + ky * d1 - p1;
@@ -7298,19 +7341,33 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, const
     concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
 }

-static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
+static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03,
+                             const int scale_factor, cudaStream_t stream) {
     int ne0 = (ne00 * scale_factor);
     int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
+    dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03);
     upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
 }

 static void pad_f32_cuda(const float * x, float * dst,
-    const int ne00, const int ne01, const int ne02,
-    const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, ne1, ne2);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
+    arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
+}
+
+static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
+                                        const int dim, const int max_period, cudaStream_t stream) {
+    int half_ceil = (dim + 1) / 2;
+    int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne00, 1);
+    timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
 }

 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
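
As a quick sanity check on the timestep-embedding launch math above, a sketch with assumed, illustrative sizes (dim = 320, 8 timesteps):

// half_ceil rounds dim/2 up, so odd dims still get their padding thread
int dim  = 320;
int ne00 = 8;  // number of timesteps
int half_ceil  = (dim + 1) / 2;                                    // = 160
int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1)
               / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;               // = 1
dim3 grid(num_blocks, ne00, 1);  // one row of blocks per timestep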
@@ -8443,8 +8500,8 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *

 template <typename T>
 static void im2col_cuda(const float* x, T* dst,
-    int IW, int IH, int OW, int OH, int KW, int KH, int IC,
-    int batch, int batch_offset, int offset_delta,
+    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+    int64_t batch, int64_t batch_offset, int64_t offset_delta,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
     const int parallel_elements = OW * KW * KH;
     const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
@@ -9123,7 +9180,7 @@ static void ggml_cuda_op_group_norm(

     int num_groups = dst->op_params[0];
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
+    group_norm_f32_cuda(src0_dd, dst_dd, num_groups * src0->ne[3], group_size, ggml_nelements(src0), main_stream);

     (void) src1;
     (void) dst;
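
To see what the group_norm change fixes, take an assumed, illustrative src0 of shape ne = [64, 64, 320, 2] with num_groups = 32:

$$
\begin{aligned}
\text{group\_size} &= 64 \cdot 64 \cdot \lceil 320/32 \rceil = 40960,\\
\text{old launch} &: 32 \text{ blocks} \Rightarrow 32 \cdot 40960 = 1310720 = ne_0 ne_1 ne_2 \text{ elements (first batch only)},\\
\text{new launch} &: 32 \cdot ne_3 = 64 \text{ blocks over } \mathrm{ggml\_nelements}(src0) = 2621440 \text{ elements (both batches)}.
\end{aligned}
$$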
@@ -9156,7 +9213,7 @@ static void ggml_cuda_op_upscale(

     const int scale_factor = dst->op_params[0];

-    upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
+    upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, main_stream);

     (void) src1;
     (void) dst;
@@ -9172,8 +9229,49 @@ static void ggml_cuda_op_pad(
     GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors

     pad_f32_cuda(src0_dd, dst_dd,
-        src0->ne[0], src0->ne[1], src0->ne[2],
-        dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_arange(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    float start;
+    float stop;
+    float step;
+    memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
+    memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
+    memcpy(&step, (float *)dst->op_params + 2, sizeof(float));
+
+    int64_t steps = (int64_t)ceil((stop - start) / step);
+    GGML_ASSERT(ggml_nelements(dst) == steps);
+
+    arange_f32_cuda(dst_dd, dst->ne[0], start, step, main_stream);
+
+    (void) src0;
+    (void) src1;
+    (void) src0_dd;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_timestep_embedding(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int dim = dst->op_params[0];
+    const int max_period = dst->op_params[1];
+
+    timestep_embedding_f32_cuda(src0_dd, dst_dd, src0->ne[0], dst->nb[1], dim, max_period, main_stream);

     (void) src1;
     (void) dst;
@@ -10458,6 +10556,45 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
 }

+static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+    const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
+
+    // dd = data device
+    float * src0_ddf = nullptr;
+    float * src1_ddf = nullptr;
+    float * dst_ddf = nullptr;
+
+    cuda_pool_alloc<float> dst_f;
+
+    ggml_cuda_set_device(g_main_device);
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    if (dst_on_device) {
+        dst_ddf = (float *) dst_extra->data_device[g_main_device];
+    } else {
+        dst_ddf = dst_f.alloc(ggml_nelements(dst));
+    }
+
+    // do the computation
+    ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+    CUDA_CHECK(cudaGetLastError());
+
+    // copy dst to host if necessary
+    if (!dst_on_device) {
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
+    }
+
+    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+}
+
+static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_timestep_embedding);
+}
+
 static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }
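
Host side, the new ops are built through the usual graph API. A minimal usage sketch, assuming an initialized ggml_context * ctx and ggml_cgraph * gf, plus the ggml_arange/ggml_timestep_embedding declarations this commit adds to ggml.h (not shown in this file):

// embed 8 evenly spaced timesteps into 320 dimensions
struct ggml_tensor * t   = ggml_arange(ctx, 0.0f, 8.0f, 1.0f);          // 8 values: 0..7
struct ggml_tensor * emb = ggml_timestep_embedding(ctx, t, 320, 10000); // shape [320, 8]
ggml_build_forward_expand(gf, emb);

Note that ggml_cuda_arange gets its own host wrapper above instead of going through ggml_cuda_op_flatten: arange has no source tensors to flatten, so the wrapper only resolves (or pool-allocates) the destination buffer before launching.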
@@ -11358,6 +11495,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         case GGML_OP_PAD:
             func = ggml_cuda_pad;
             break;
+        case GGML_OP_ARANGE:
+            func = ggml_cuda_arange;
+            break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            func = ggml_cuda_timestep_embedding;
+            break;
         case GGML_OP_LEAKY_RELU:
             func = ggml_cuda_leaky_relu;
             break;
@@ -12253,6 +12396,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_ARANGE:
+        case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
             return true;
         default:
