Skip to content

ggml : implement REGLU/GEGLU/SWIGLU ops #14158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: master
Choose a base branch
from

Conversation

CISC
Copy link
Collaborator

@CISC CISC commented Jun 12, 2025

Implement REGLU/GEGLU/SWIGLU ops to avoid unnecessary tensor duplications and a little more efficient execution by combining ops in one.

Only CPU and CUDA right now, help needed to complete other backends!

@CISC CISC added the help wanted Extra attention is needed label Jun 12, 2025
@CISC CISC requested a review from ggerganov June 12, 2025 23:25
@github-actions github-actions bot added testing Everything test related Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Jun 12, 2025
Copy link
Member

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed that these ops change the shape of the input tensor.

I think it would be better to introduce:

enum ggml_glu_op {
    GGML_GLU_OP_REGLU,
    GGML_GLU_OP_GEGLU,
    GGML_GLU_OP_SWIGLU,
};

// similar to ggml_unary()
GGML_API struct ggml_tensor * ggml_glu(
        struct ggml_context * ctx,
         struct ggml_tensor * a,
           enum ggml_glu_op   op);

// these simply call ggml_glu()
GGML_API struct ggml_tensor * ggml_reglu(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

GGML_API struct ggml_tensor * ggml_geglu(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

GGML_API struct ggml_tensor * ggml_swiglu(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

@CISC CISC changed the title ggml : implement unary REGLU/GEGLU/SWIGLU ops ggml : implement REGLU/GEGLU/SWIGLU ops Jun 13, 2025
@CISC CISC requested a review from ggerganov June 13, 2025 08:23
Copy link
Member

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hope we don't forget to implement these in the rest of the backends.

Adding @JohannesGaessler for review of the CUDA changes.

@ggerganov
Copy link
Member

Only CPU and CUDA right now, help needed to complete other backends!

Yes, let's add the rest of the backends first before merging. At least Metal and Vulkan.

@JohannesGaessler
Copy link
Collaborator

More generally, I've been thinking that it would be useful to have something like a backend-specific graph optimization step in ggml. That way you could do things like fuse tensors only if the fused tensor is supported by the backend and only if using it makes sense given the tensor shapes.

@CISC
Copy link
Collaborator Author

CISC commented Jun 13, 2025

Only CPU and CUDA right now, help needed to complete other backends!

Yes, let's add the rest of the backends first before merging. At least Metal and Vulkan.

Any suggestions on who could help with that?

@github-actions github-actions bot added the Apple Metal https://en.wikipedia.org/wiki/Metal_(API) label Jun 13, 2025
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_swiglu(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just want to note that I have been observing one variants of swiglu. it's used by ultravox, which sigmoid the second half of the vector instead of the first half

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, interesting, worth adding a parameter for, or best just handling in conversion?
https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_3-70b/blob/main/ultravox_model.py#L701-L704

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be nice to have a param since the GGUFs are already on the internet. Haven't thought about permuting the FFN up tensor before, nice suggestion

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added swapped variants.

@ggerganov I didn't dare update metal code, so needs to be implemented there too. :)

@JohannesGaessler
Copy link
Collaborator

@0cc4m @jeffbolznv are either of you interested in a Vulkan implementation?

@0cc4m
Copy link
Collaborator

0cc4m commented Jun 13, 2025

I can look into it tomorrow.

@JohannesGaessler
Copy link
Collaborator

CUDA performance test:

GPU Model Microbatch size Test t/s master t/s cisc/unary-reglu-geglu-swiglu Speedup
RTX 4090 chatglm 9B Q4_0 1 pp512 157.48 160.49 1.02
RTX 4090 chatglm 9B Q4_0 2 pp512 268.16 276.78 1.03
RTX 4090 chatglm 9B Q4_0 4 pp512 517.41 535.36 1.03
RTX 4090 chatglm 9B Q4_0 8 pp512 826.69 855.46 1.03
RTX 4090 chatglm 9B Q4_0 16 pp512 1407.13 1453.62 1.03
RTX 4090 chatglm 9B Q4_0 32 pp512 2545.45 2664.80 1.05
RTX 4090 chatglm 9B Q4_0 64 pp512 4414.61 4704.57 1.07
RTX 4090 chatglm 9B Q4_0 128 pp512 6467.60 7028.01 1.09
RTX 4090 chatglm 9B Q4_0 256 pp512 8670.62 9451.16 1.09
RTX 4090 chatglm 9B Q4_0 512 pp512 9842.99 10832.14 1.10

Also a plot of the same data using #14169 :

plot

@ggerganov
Copy link
Member

CUDA performance test:

Huh, I didn't expect the benefit to be that much. Interesting.

@jeffbolznv
Copy link
Collaborator

Huh, I didn't expect the benefit to be that much. Interesting.

Everything other than mat-mat mul is either bandwidth or small dispatch limited. Fusion is a big opportunity. We should reopen discussions about how to enable more types of fusion.

@CISC
Copy link
Collaborator Author

CISC commented Jun 13, 2025

CUDA performance test:

Nice! Will be interesting to see numbers on other backends as well...

@CISC
Copy link
Collaborator Author

CISC commented Jun 13, 2025

Hmmm, it just occurred to me that we should be able to (now that I pass along a pointer to the gate separately) perform these ops on models with separate ffn_up/gate tensors too by conditionally setting src[1].

struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_geglu(
Copy link
Collaborator

@ngxson ngxson Jun 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tbh I don't even know why geglu was added in the first place. It doesn't seem to be used by any models. And to make matter worse, the PR where it was added has no useful description: #14074

So I wonder if we actually need to implement it as a kernel. The current kernel use tanh approximation, but in practice, there can be many different approximations for gelu op.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've seen several, and in fact we already support a few (Gemma, DeepSeekV1, Jina-Bert and T5), it's just that the gate is split (some at conversion because we didn't have the op).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I wonder if we actually need to implement it as a kernel. The current kernel use tanh approximation, but in practice, there can be many different approximations for gelu op.

It's pretty easy adding different GLU ops (and in CUDA I even reuse the original op), adding GEGLU_ERF if necessary shouldn't be a problem.

@github-actions github-actions bot added the Vulkan Issues specific to the Vulkan backend label Jun 14, 2025
@0cc4m
Copy link
Collaborator

0cc4m commented Jun 14, 2025

I implemented Vulkan shaders for the new ops.

@qnixsynapse
Copy link
Collaborator

Interesting.. I tried implementing for SYCL, saw little improvement. When I saw the graph logs, it wasn't using the fused kernels for llama 3.2 3B.

[SYCL][OP] call ggml_sycl_mul_mat: dst='ffn_gate-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768]	src0='blk.19.ffn_gate.weight':type=f16;ne=[3072, 8192, 1, 1];nb=[2, 6144, 50331648, 50331648]	src1='ffn_norm-19':type=f32;ne=[3072, 1, 1, 1];nb=[4, 12288, 12288, 12288]
[SYCL][OP] call ggml_sycl_op_dequantize_mul_mat_vec/to_fp16_sycl: dst='ffn_gate-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768]	src0='blk.19.ffn_gate.weight':type=f16;ne=[3072, 8192, 1, 1];nb=[2, 6144, 50331648, 50331648]	src1='ffn_norm-19':type=f32;ne=[3072, 1, 1, 1];nb=[4, 12288, 12288, 12288] : converting src1 to fp16
[SYCL][OP] call ggml_sycl_op_dequantize_mul_mat_vec/to_fp16_sycl done
[SYCL][OP] call ggml_sycl_mul_mat done
[SYCL][OP] call ggml_sycl_silu: dst='ffn_silu-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768]	src0='ffn_gate-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768]
[SYCL][OP] call ggml_sycl_silu done
[SYCL][OP] call ggml_sycl_mul_mat: dst='ffn_up-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768]	src0='blk.19.ffn_up.weight':type=f16;ne=[3072, 8192, 1, 1];nb=[2, 6144, 50331648, 50331648]	src1='ffn_norm-19':type=f32;ne=[3072, 1, 1, 1];nb=[4, 12288, 12288, 12288]
[SYCL][OP] call ggml_sycl_op_dequantize_mul_mat_vec/to_fp16_sycl: dst='ffn_up-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768]	src0='blk.19.ffn_up.weight':type=f16;ne=[3072, 8192, 1, 1];nb=[2, 6144, 50331648, 50331648]	src1='ffn_norm-19':type=f32;ne=[3072, 1, 1, 1];nb=[4, 12288, 12288, 12288] : converting src1 to fp16
[SYCL][OP] call ggml_sycl_op_dequantize_mul_mat_vec/to_fp16_sycl done
[SYCL][OP] call ggml_sycl_mul_mat done
[SYCL][OP] call ggml_sycl_mul: dst='ffn_gate_par-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768]	src0='ffn_silu-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768]	src1='ffn_up-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768]
[SYCL][OP] call ggml_sycl_mul done
[SYCL][OP] call ggml_sycl_mul_mat: dst='ffn_out-19':type=f32;ne=[3072, 1, 1, 1];nb=[4, 12288, 12288, 12288]	src0='blk.19.ffn_down.weight':type=f16;ne=[8192, 3072, 1, 1];nb=[2, 16384, 50331648, 50331648]	src1='ffn_gate_par-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768]
[SYCL][OP] call ggml_sycl_op_dequantize_mul_mat_vec/to_fp16_sycl: dst='ffn_out-19':type=f32;ne=[3072, 1, 1, 1];nb=[4, 12288, 12288, 12288]	src0='blk.19.ffn_down.weight':type=f16;ne=[8192, 3072, 1, 1];nb=[2, 16384, 50331648, 50331648]	src1='ffn_gate_par-19':type=f32;ne=[8192, 1, 1, 1];nb=[4, 32768, 32768, 32768] : converting src1 to fp16
[SYCL][OP] call ggml_sycl_op_dequantize_mul_mat_vec/to_fp16_sycl done
[SYCL][OP] call ggml_sycl_mul_mat done
[SYCL][OP] call ggml_sycl_add: dst='l_out-19':type=f32;ne=[3072, 1, 1, 1];nb=[4, 12288, 12288, 12288]	src0='ffn_out-19':type=f32;ne=[3072, 1, 1, 1];nb=[4, 12288, 12288, 12288]	src1='ffn_inp-19':type=f32;ne=[3072, 1, 1, 1];nb=[4, 12288, 12288, 12288]
[SYCL][OP] call ggml_sycl_add done

I am using from this branch.

@CISC
Copy link
Collaborator Author

CISC commented Jun 14, 2025

Interesting.. I tried implementing for SYCL, saw little improvement. When I saw the graph logs, it wasn't using the fused kernels for llama 3.2 3B.

That's normal, llama 3.2 doesn't have a single up+gate does it?

@qnixsynapse
Copy link
Collaborator

qnixsynapse commented Jun 14, 2025

IIRC, it has SWIGLU. But I didn't check if it is using a single up+gate or not.

Edit: Nevermind, Need to check #14181 I guess

@github-actions github-actions bot added the SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language label Jun 14, 2025
@CISC
Copy link
Collaborator Author

CISC commented Jun 14, 2025

Edit: Nevermind, Need to check #14181 I guess

Yep. :)

@qnixsynapse
Copy link
Collaborator

Yep. :)

Please merge this PR first so that I can adjust the existing kernels for split up and gate. :)

I will deduplicate the SYCL code then.

@CISC
Copy link
Collaborator Author

CISC commented Jun 14, 2025

Please merge this PR first so that I can adjust the existing kernels for split up and gate. :)

The plan is to merge #14181 into this one once @ggerganov signs off on it, then backends can be updated, and once all tests go green, merge into master.

@CISC
Copy link
Collaborator Author

CISC commented Jun 14, 2025

@qnixsynapse If you want you can bring the other branch up-to-date and add your changes there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Apple Metal https://en.wikipedia.org/wiki/Metal_(API) ggml changes relating to the ggml tensor library for machine learning help wanted Extra attention is needed Nvidia GPU Issues specific to Nvidia GPUs SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language testing Everything test related Vulkan Issues specific to the Vulkan backend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants