1 parent e0e806f commit 0b4be4c
ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     float KQ_max_scale[cols_per_thread];
 #pragma unroll
     for (int col = 0; col < cols_per_thread; ++col) {
-        KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+        const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
+        KQ_max_scale[col] = expf(KQ_max_diff);
         KQ_max[col] = KQ_max_new[col];

+        *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
         // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
         KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
     }
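For reference, the added `*((uint32_t *) &KQ_max_scale[col]) *= ...` line is a branchless flush-to-zero: the comparison evaluates to 0 or 1, and multiplying the float's bit pattern by 0 yields 0x00000000, which is exactly 0.0f, while multiplying by 1 leaves the value untouched. Below is a minimal host-side sketch of the same trick; the threshold value here is a stand-in for illustration, not the actual SOFTMAX_FTZ_THRESHOLD defined in the ggml CUDA headers:

#include <math.h>
#include <stdint.h>
#include <stdio.h>

// Stand-in value; the real SOFTMAX_FTZ_THRESHOLD is defined elsewhere in ggml.
#define SOFTMAX_FTZ_THRESHOLD -20.0f

int main(void) {
    const float diffs[] = {-1.0f, -25.0f};  // one above, one below the threshold

    for (float diff : diffs) {
        float scale = expf(diff);

        // Branchless flush-to-zero: (diff >= threshold) is 0 or 1. Multiplying
        // the reinterpreted bit pattern by 0 produces 0x00000000 (exactly 0.0f);
        // multiplying by 1 leaves the float unchanged.
        *((uint32_t *) &scale) *= diff >= SOFTMAX_FTZ_THRESHOLD;

        printf("diff = %6.1f -> scale = %g\n", diff, scale);
    }
    return 0;
}

Presumably the point of doing this with integer bit manipulation rather than an if is to avoid a divergent branch in the inner loop while guaranteeing that scales from very negative diffs become exactly zero instead of tiny nonzero values.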