JohannesGaessler committed on
Commit
6e3a7b6
·
1 Parent(s): 7cdf9cd

CUDA: fix negative KV_max values in FA (llama/15321)

Browse files
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
539
  all_inf = warp_reduce_all(all_inf);
540
 
541
  if (!all_inf) {
542
- KV_max_sj += FATTN_KQ_STRIDE;
543
  break;
544
  }
545
  }
546
 
 
 
 
 
 
547
  if (threadIdx.x != 0) {
548
  return;
549
  }
 
539
  all_inf = warp_reduce_all(all_inf);
540
 
541
  if (!all_inf) {
 
542
  break;
543
  }
544
  }
545
 
546
+ // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
547
+ // If the break was triggered it's the lower edge of the tile with the first non-masked values.
548
+ // In either case, walk back the decrementation by FATTN_KQ_STRIDE.
549
+ KV_max_sj += FATTN_KQ_STRIDE;
550
+
551
  if (threadIdx.x != 0) {
552
  return;
553
  }