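// Select ncols1, the number of Q columns (tokens) processed per kernel tile, for a fixed ncols2.
// The smallest tile of 8/16/32/64 total columns that still covers Q->ne[1] is used; the 8-column
// tile only exists for ncols2 <= 8, and when the highest compiled arch is Turing the tile is
// capped at 32 total columns (presumably the 64-column variant is not compiled or not profitable there).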
template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const ggml_tensor * Q = dst->src[0];

    if constexpr (ncols2 <= 8) {
        if (Q->ne[1] <= 8/ncols2) {
            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
            return;
        }
    }

    if (Q->ne[1] <= 16/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
        return;
    }

    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
}

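// Select ncols2, the number of Q heads sharing one K/V head that are packed into a tile together.
// This GQA packing is only applied when a mask is present and there is no ALiBi bias
// (max_bias == 0.0f); otherwise, or for odd GQA ratios, ncols2 falls back to 1.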
template <int DKQ, int DV>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * mask = dst->src[3];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    const bool use_gqa_opt = mask && max_bias == 0.0f;

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
    const int gqa_ratio = Q->ne[2] / K->ne[2];

    if (use_gqa_opt && gqa_ratio % 8 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio % 4 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio % 2 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
}

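// Dispatch the tensor-core (MMA) FlashAttention kernel on the K/Q head size Q->ne[0].
// Regular models use DKQ == DV; the 576/512 case corresponds to DeepSeek's MLA layout and
// requires the GQA optimization with a GQA ratio divisible by 16.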
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    switch (Q->ne[0]) {
        case 64:
            GGML_ASSERT(V->ne[0] == 64);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64,  64>(ctx, dst);
            break;
        case 80:
            GGML_ASSERT(V->ne[0] == 80);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80,  80>(ctx, dst);
            break;
        case 96:
            GGML_ASSERT(V->ne[0] == 96);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96,  96>(ctx, dst);
            break;
        case 112:
            GGML_ASSERT(V->ne[0] == 112);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
            break;
        case 128:
            GGML_ASSERT(V->ne[0] == 128);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
            break;
        case 256:
            GGML_ASSERT(V->ne[0] == 256);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
            break;
        case 576: {
            // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
            GGML_ASSERT(V->ne[0] == 512);

            float max_bias = 0.0f;
            memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

            const bool use_gqa_opt = mask && max_bias == 0.0f;
            GGML_ASSERT(use_gqa_opt);

            GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
            const int gqa_ratio = Q->ne[2] / K->ne[2];
            GGML_ASSERT(gqa_ratio % 16 == 0);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
        } break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

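// Vector FlashAttention kernel with FP16 accumulation, intended for small batches.
// FATTN_VEC_F16_CASE(D, type_K, type_V) is a macro defined elsewhere in this file; it is expected
// to launch the matching template instantiation and return when Q->ne[0], K->type and V->type
// match. The full matrix of K/V quantization combinations is only compiled with
// GGML_CUDA_FA_ALL_QUANTS; otherwise only the reduced set after the #else is available and
// unmatched combinations fall through to on_no_fattn_vec_case.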
static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#else
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}

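// Same dispatch as above but with FP32 accumulation, used when higher precision is requested
// or fast FP16 arithmetic is unavailable.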
static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#else
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}

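// Top-level FlashAttention dispatcher: selects between the WMMA, MMA, tile and vector kernels
// based on device capabilities (vendor, FP16 throughput, tensor-core generation), the requested
// precision and the problem shape.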
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    ggml_cuda_set_device(ctx.device);
    const int cc        = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
    }
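    // Without fast FP16 arithmetic only the FP32-accumulating kernels are used:
    // small batches and head size 256 go to the vector kernel, everything else to the tile kernel.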
    if (!fast_fp16_available(cc)) {
        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
        }
        return;
    }
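    // FP16 is fast but FP16 tensor cores are unavailable: pick the FP16 or FP32 variant of the
    // vector/tile kernels according to the requested precision.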
    if (!fp16_mma_available(cc)) {
        if (prec == GGML_PREC_DEFAULT) {
            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
            } else {
                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
            }
        } else {
            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
            } else {
                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
            }
        }
        return;
    }
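    // Heuristics for batch size 1: the MMA kernel can still beat the vector kernel when its GQA
    // packing applies and K/V need no conversion to FP16; on Ada Lovelace and newer it is only
    // preferred for multi-sequence batches or large GQA ratios with long contexts.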
    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
    const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
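    // A single Q column uses the vector kernel if the head size allows it and the MMA kernel is
    // not expected to be faster; the requested precision selects FP16 vs. FP32 accumulation.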
    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
        if (prec == GGML_PREC_DEFAULT) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }

    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
}