JohannesGaessler committed
Commit 6df9571 · 1 Parent(s): acfd94f

CUDA: fix Volta FlashAttention logic (llama/11615)
ggml/src/ggml-cuda/fattn-wmma-f16.cu CHANGED

@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
             ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
             break;
         // case 256:
-        //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+        //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
         //     break;
         default:
             GGML_ABORT("fatal error");
ggml/src/ggml-cuda/fattn.cu CHANGED

@@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         return;
     }
 
-    if (!new_mma_available(cc)) {
+    if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
             if (Q->ne[1] <= 8) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
@@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
     if (cc == GGML_CUDA_CC_VOLTA) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        return;
     }
 
     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
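
For context: before this change the Volta branch launched the WMMA kernel and then fell through into the MMA kernel (which needs Turing or newer), and the fallback check used new_mma_available() instead of fp16_mma_available(). The sketch below is a hypothetical, self-contained rendering of the corrected dispatch flow; the compute-capability constants and stub functions only mirror the names in the diff and are not the real ggml CUDA backend API.

// Minimal sketch (assumption-based, not the real ggml code) of the corrected
// FlashAttention dispatch: no tensor cores -> vector kernel, Volta -> WMMA
// kernel and return, Turing or newer -> MMA kernel.
#include <cstdio>

constexpr int GGML_CUDA_CC_VOLTA  = 700; // assumed value for compute capability 7.0
constexpr int GGML_CUDA_CC_TURING = 750; // assumed value for compute capability 7.5

// Assumed stand-in for the real availability check: FP16 WMMA needs Volta or newer.
bool fp16_mma_available(int cc) { return cc >= GGML_CUDA_CC_VOLTA; }

// Stub functions standing in for the real kernel launch wrappers.
void flash_attn_vec_f16 () { std::puts("vector kernel");            }
void flash_attn_wmma_f16() { std::puts("WMMA kernel (Volta path)"); }
void flash_attn_mma_f16 () { std::puts("MMA kernel (Turing+)");     }

void flash_attn_dispatch(int cc) {
    if (!fp16_mma_available(cc)) { // no FP16 tensor cores at all -> vector fallback
        flash_attn_vec_f16();
        return;
    }
    if (cc == GGML_CUDA_CC_VOLTA) { // tensor cores, but the MMA kernel needs Turing+
        flash_attn_wmma_f16();
        return;                     // the return added by this commit: do not fall through
    }
    flash_attn_mma_f16();
}

int main() {
    flash_attn_dispatch(610); // e.g. Pascal: vector kernel
    flash_attn_dispatch(700); // Volta: WMMA kernel only
    flash_attn_dispatch(890); // e.g. Ada: MMA kernel
}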