Spaces:
Running
Running
Commit
·
6e3a7b6
1
Parent(s):
7cdf9cd
CUDA: fix negative KV_max values in FA (llama/15321)
Browse files
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED
|
@@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
|
|
| 539 |
all_inf = warp_reduce_all(all_inf);
|
| 540 |
|
| 541 |
if (!all_inf) {
|
| 542 |
-
KV_max_sj += FATTN_KQ_STRIDE;
|
| 543 |
break;
|
| 544 |
}
|
| 545 |
}
|
| 546 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
if (threadIdx.x != 0) {
|
| 548 |
return;
|
| 549 |
}
|
|
|
|
| 539 |
all_inf = warp_reduce_all(all_inf);
|
| 540 |
|
| 541 |
if (!all_inf) {
|
|
|
|
| 542 |
break;
|
| 543 |
}
|
| 544 |
}
|
| 545 |
|
| 546 |
+
// If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
|
| 547 |
+
// If the break was triggered, KV_max_sj is the lower edge of the tile with the first non-masked values.
|
| 548 |
+
// In either case, walk back the decrement by FATTN_KQ_STRIDE.
|
| 549 |
+
KV_max_sj += FATTN_KQ_STRIDE;
|
| 550 |
+
|
| 551 |
if (threadIdx.x != 0) {
|
| 552 |
return;
|
| 553 |
}
|