Spaces:
Running
Running
Commit
·
6df9571
1
Parent(s):
acfd94f
CUDA: fix Volta FlashAttention logic (llama/11615)
Browse files
ggml/src/ggml-cuda/fattn-wmma-f16.cu
CHANGED
|
@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
|
|
| 561 |
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
| 562 |
break;
|
| 563 |
// case 256:
|
| 564 |
-
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
| 565 |
// break;
|
| 566 |
default:
|
| 567 |
GGML_ABORT("fatal error");
|
|
|
|
| 561 |
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
| 562 |
break;
|
| 563 |
// case 256:
|
| 564 |
+
// ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
|
| 565 |
// break;
|
| 566 |
default:
|
| 567 |
GGML_ABORT("fatal error");
|
ggml/src/ggml-cuda/fattn.cu
CHANGED
|
@@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 235 |
return;
|
| 236 |
}
|
| 237 |
|
| 238 |
-
if (!new_mma_available(cc)) {
|
| 239 |
if (prec == GGML_PREC_DEFAULT) {
|
| 240 |
if (Q->ne[1] <= 8) {
|
| 241 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
|
@@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 265 |
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
| 266 |
if (cc == GGML_CUDA_CC_VOLTA) {
|
| 267 |
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
|
|
|
| 268 |
}
|
| 269 |
|
| 270 |
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
|
|
|
| 235 |
return;
|
| 236 |
}
|
| 237 |
|
| 238 |
+
if (!fp16_mma_available(cc)) {
|
| 239 |
if (prec == GGML_PREC_DEFAULT) {
|
| 240 |
if (Q->ne[1] <= 8) {
|
| 241 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
|
|
|
| 265 |
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
| 266 |
if (cc == GGML_CUDA_CC_VOLTA) {
|
| 267 |
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
| 268 |
+
return;
|
| 269 |
}
|
| 270 |
|
| 271 |
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|