Skip to content

Commit 0b4be4c

Browse files
CUDA: fix FTZ in FA for Gemma 3 (#13991)
1 parent e0e806f commit 0b4be4c

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 float KQ_max_scale[cols_per_thread];
 #pragma unroll
 for (int col = 0; col < cols_per_thread; ++col) {
-    KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+    const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
+    KQ_max_scale[col] = expf(KQ_max_diff);
     KQ_max[col] = KQ_max_new[col];
 
+    *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
     // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
     KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
 }

0 commit comments

Comments
 (0)