JohannesGaessler committed on
Commit ace16dc · 1 Parent(s): f8fd66d

CUDA: faster Deepseek FA, add Turing support (llama/13435)

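Background for the diffs below: with DeepSeek's MLA layout, V is just the trailing DV elements of each K row (DKQ = 576, DV = 512 in the <576, 512> kernel config), so the CUDA kernels can alias V into K instead of reading a separate tensor, and launch_fattn now accepts a missing V. A minimal host-side sketch of that aliasing in plain C++ (illustrative only, not the ggml API):

#include <cstdio>

int main() {
    constexpr int DKQ = 576; // K/Q head size in the MLA kernel config
    constexpr int DV  = 512; // V head size

    // The kernels work on half2 pointers, so the alias used in the diff is
    //     V_h2 = K_h2 + (DKQ/2 - DV/2);
    // i.e. a V row starts (DKQ - DV) scalar elements into the matching K row.
    constexpr int offset_half2 = DKQ/2 - DV/2;
    std::printf("V view starts %d half2 (%d scalars) into each K row\n",
                offset_half2, 2*offset_half2);
    return 0;
}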
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -678,10 +678,14 @@ void launch_fattn(
678
  ) {
679
  constexpr int ncols = ncols1 * ncols2;
680
 
 
 
681
  const ggml_tensor * Q = dst->src[0];
682
  const ggml_tensor * K = dst->src[1];
683
  const ggml_tensor * V = dst->src[2];
684
 
 
 
685
  const ggml_tensor * mask = dst->src[3];
686
 
687
  ggml_tensor * KQV = dst;
@@ -689,6 +693,10 @@ void launch_fattn(
689
  GGML_ASSERT(Q->type == GGML_TYPE_F32);
690
  GGML_ASSERT(KQV->type == GGML_TYPE_F32);
691
692
  GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
693
  GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
694
  "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -713,10 +721,10 @@ void launch_fattn(
713
  size_t nb12 = K->nb[2];
714
  size_t nb13 = K->nb[3];
715
 
716
- const char * V_data = (const char *) V->data;
717
- size_t nb21 = V->nb[1];
718
- size_t nb22 = V->nb[2];
719
- size_t nb23 = V->nb[3];
720
 
721
  if (need_f16_K && K->type != GGML_TYPE_F16) {
722
  GGML_ASSERT(ggml_is_contiguously_allocated(K));
@@ -733,7 +741,7 @@ void launch_fattn(
733
  nb13 = nb13*bs*sizeof(half)/ts;
734
  }
735
 
736
- if (need_f16_V && V->type != GGML_TYPE_F16) {
737
  GGML_ASSERT(ggml_is_contiguously_allocated(V));
738
  V_f16.alloc(ggml_nelements(V));
739
  to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
 
678
  ) {
679
  constexpr int ncols = ncols1 * ncols2;
680
 
681
+ const bool is_mla = DV == 512; // TODO better parameterization
682
+
683
  const ggml_tensor * Q = dst->src[0];
684
  const ggml_tensor * K = dst->src[1];
685
  const ggml_tensor * V = dst->src[2];
686
 
687
+ GGML_ASSERT(V || is_mla);
688
+
689
  const ggml_tensor * mask = dst->src[3];
690
 
691
  ggml_tensor * KQV = dst;
 
693
  GGML_ASSERT(Q->type == GGML_TYPE_F32);
694
  GGML_ASSERT(KQV->type == GGML_TYPE_F32);
695
 
696
+ GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
697
+ GGML_ASSERT( K->nb[0] == ggml_element_size(K));
698
+ GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
699
+
700
  GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
701
  GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
702
  "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
721
  size_t nb12 = K->nb[2];
722
  size_t nb13 = K->nb[3];
723
 
724
+ const char * V_data = V ? (const char *) V->data : nullptr;
725
+ size_t nb21 = V ? V->nb[1] : nb11;
726
+ size_t nb22 = V ? V->nb[2] : nb12;
727
+ size_t nb23 = V ? V->nb[3] : nb13;
728
 
729
  if (need_f16_K && K->type != GGML_TYPE_F16) {
730
  GGML_ASSERT(ggml_is_contiguously_allocated(K));
 
741
  nb13 = nb13*bs*sizeof(half)/ts;
742
  }
743
 
744
+ if (V && need_f16_V && V->type != GGML_TYPE_F16) {
745
  GGML_ASSERT(ggml_is_contiguously_allocated(V));
746
  V_f16.alloc(ggml_nelements(V));
747
  to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
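The launch_fattn changes above reduce to one rule: if V is absent (the MLA case), fall back to K's data pointer and byte strides. A small standalone sketch of that fallback, using a hypothetical Tensor struct rather than ggml_tensor:

#include <cstddef>

struct Tensor {               // hypothetical stand-in for ggml_tensor
    const char * data;
    size_t       nb[4];       // byte strides, nb[1..3] as used by launch_fattn
};

struct VView {
    const char * data;
    size_t nb1, nb2, nb3;
};

// Mirrors the fallback above: without a V tensor the launcher walks V through K's
// strides, and the device code derives the actual V pointer from K (see fattn-mma-f16.cuh).
static VView make_v_view(const Tensor * K, const Tensor * V) {
    VView v;
    v.data = V ? V->data  : nullptr;
    v.nb1  = V ? V->nb[1] : K->nb[1];
    v.nb2  = V ? V->nb[2] : K->nb[2];
    v.nb3  = V ? V->nb[3] : K->nb[3];
    return v;
}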
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> {
33
  static constexpr int nwarps_max = 4;
34
  static constexpr bool Q_in_reg = true;
35
  static constexpr int nstages_target = 2;
36
- static constexpr int nbatch_K2 = 32;
37
- static constexpr int nbatch_V2 = 32;
38
 - static constexpr int nbatch_combine = 32;
39
  };
40
 
41
  template <>
@@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> {
44
  static constexpr int nwarps_max = 4;
45
  static constexpr bool Q_in_reg = true;
46
  static constexpr int nstages_target = 2;
47
- static constexpr int nbatch_K2 = 40;
48
- static constexpr int nbatch_V2 = 40;
49
 - static constexpr int nbatch_combine = 40;
50
  };
51
 
52
  template <>
@@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> {
55
  static constexpr int nwarps_max = 4;
56
  static constexpr bool Q_in_reg = true;
57
  static constexpr int nstages_target = 2;
58
- static constexpr int nbatch_K2 = 48;
59
- static constexpr int nbatch_V2 = 48;
60
 - static constexpr int nbatch_combine = 48;
61
  };
62
 
63
  template <>
@@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
66
  static constexpr int nwarps_max = 4;
67
  static constexpr bool Q_in_reg = true;
68
  static constexpr int nstages_target = 2;
69
- static constexpr int nbatch_K2 = 56;
70
- static constexpr int nbatch_V2 = 56;
71
 - static constexpr int nbatch_combine = 56;
72
  };
73
 
74
  template <>
@@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
77
  static constexpr int nwarps_max = 4;
78
  static constexpr bool Q_in_reg = true;
79
  static constexpr int nstages_target = 2;
80
- static constexpr int nbatch_K2 = 64;
81
- static constexpr int nbatch_V2 = 64;
82
 - static constexpr int nbatch_combine = 64;
83
  };
84
 
85
  template <>
@@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
88
  static constexpr int nwarps_max = 4;
89
  static constexpr bool Q_in_reg = true;
90
  static constexpr int nstages_target = 2;
91
- static constexpr int nbatch_K2 = 128;
92
- static constexpr int nbatch_V2 = 128;
93
 - static constexpr int nbatch_combine = 128;
94
  };
95
 
96
  template <>
@@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
99
  static constexpr int nwarps_max = 8;
100
  static constexpr bool Q_in_reg = false;
101
  static constexpr int nstages_target = 1;
102
- static constexpr int nbatch_K2 = 160;
103
- static constexpr int nbatch_V2 = 128;
104
 - static constexpr int nbatch_combine = 128;
105
  };
106
 
107
  // ------------------------------------------------------------------------------------------------------------------
@@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
120
 
121
  const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
122
 
123
- auto load = [&] __device__ (const int n) {
124
  const int stride_k = WARP_SIZE >> n;
125
  const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
126
  const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
@@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
223
  }
224
  }
225
 
226
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
227
  static __device__ __forceinline__ void flash_attn_ext_f16_iter(
228
  const float2 * const __restrict__ Q_f2,
229
  const half2 * const __restrict__ K_h2,
@@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
261
  constexpr int cols_per_warp = ntiles * tile_B::I;
262
  constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
263
  constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
 
 
 
264
 
265
- constexpr int stride_tile_Q = DKQ/2 + 4;
266
- constexpr int stride_tile_K = c::nbatch_K2 + 4;
267
- constexpr int stride_tile_V = c::nbatch_V2 + 4;
 
 
268
 
269
  const int k_VKQ_0 = kb0 * c::nbatch_fa;
270
  tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
@@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
275
  tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
276
 
277
  if constexpr (nstages > 1) {
278
- static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
 
279
  constexpr bool use_cp_async = true;
280
  cp_async_wait_all();
281
  __syncthreads();
282
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
283
- (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
284
  } else {
285
  constexpr bool use_cp_async = nstages == 1;
286
  if (ncols2 > 1 || mask_h2) {
@@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
289
  }
290
 
291
  #pragma unroll
292
- for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
293
- const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
294
  const int k0_diff = k0_stop - k0_start;
295
 
296
  if (nstages <= 1) {
@@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
537
  (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
538
  }
539
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
540
- (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
541
  }
542
  }
543
 
 
 
 
 
 
544
  #pragma unroll
545
- for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
546
- const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
547
- const int i0_diff = i0_stop - i0_start;
548
 
549
- if (nstages <= 1) {
550
  constexpr bool use_cp_async = nstages == 1;
551
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
552
  (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
@@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
555
  }
556
  __syncthreads();
557
  }
 
558
 
559
  // Calculate VKQ tile:
560
  #pragma unroll
@@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
565
  const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
566
 
567
  tile_A A;
568
- load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
569
  if (ntiles == 1) {
570
  mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
571
  } else {
@@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
596
  #endif // NEW_MMA_AVAILABLE
597
  }
598
 
599
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
600
  static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
601
  const float2 * const __restrict__ Q_f2,
602
  const half2 * const __restrict__ K_h2,
@@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
632
  constexpr int cols_per_warp = ntiles * tile_B::I;
633
  constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
634
  constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
 
 
635
 
636
  static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
637
 
638
- constexpr int stride_tile_Q = DKQ/2 + 4;
639
- constexpr int stride_tile_K = c::nbatch_K2 + 4;
640
- constexpr int stride_tile_V = c::nbatch_V2 + 4;
641
 
 
 
642
  constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
643
 
644
  extern __shared__ half2 tile_Q[];
@@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
726
 
727
  // Preload mask and K data for first iteration when using cp_async with multiple stages:
728
  if constexpr (nstages > 1) {
729
- static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
730
  constexpr bool use_cp_async = true;
731
  if (ncols2 > 1 || mask_h2) {
732
  flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
733
  (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
734
  }
735
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
736
- (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
737
  }
738
 
739
  // Iterate over ne11 == previous tokens:
740
  for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
741
  constexpr bool last_iter = false;
742
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
743
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
744
  ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
745
  }
746
  { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
747
  constexpr bool last_iter = true;
748
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
749
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
750
  ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
751
  }
@@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
774
  // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
775
  // So also write VKQ accumulators to shared memory in column-major format if np == 1.
776
 
777
- constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
778
  constexpr int tile_stride = nbatch_combine + 4;
779
  static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
780
 
@@ -1012,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1012
  #endif // NEW_MMA_AVAILABLE
1013
  }
1014
 
1015
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
1016
  __launch_bounds__(nwarps*WARP_SIZE, 1)
1017
  static __global__ void flash_attn_ext_f16(
1018
  const char * __restrict__ Q,
@@ -1057,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
1057
  NO_DEVICE_CODE;
1058
  return;
1059
  }
1060
 
1061
  typedef fattn_mma_f16_config<DKQ, DV> c;
1062
 
@@ -1067,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
1067
  const int stride_Q1 = nb01 / sizeof(float2);
1068
  const int stride_Q2 = nb02 / sizeof(float2);
1069
  const int stride_K = nb11 / sizeof(half2);
1070
- const int stride_V = nb21 / sizeof(half2);
1071
  const int stride_mask = nb31 / sizeof(half2);
1072
 
 
 
1073
  const int iter_k = ne11 / FATTN_KQ_STRIDE;
1074
  const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
1075
 
@@ -1092,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
1092
 
1093
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1094
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1095
- const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1096
  const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1097
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1098
 
 
 
1099
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1100
 
1101
  const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1104,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
1104
  constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
1105
  if (kb0_start == 0) {
1106
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
1107
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
1108
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1109
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1110
  } else {
1111
  constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
1112
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
1113
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1114
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1115
  }
@@ -1130,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
1130
 
1131
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1132
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1133
- const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
1134
  const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1135
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1136
 
 
 
1137
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1138
 
1139
  const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1141,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
1141
 
1142
  constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
1143
  constexpr bool needs_fixup = false;
1144
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
1145
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1146
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1147
  #else
@@ -1167,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1167
 
1168
  typedef fattn_mma_f16_config<DKQ, DV> c;
1169
 
1170
- constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
1171
- constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
1172
- constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
1173
-
1174
  const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
1175
 
1176
  constexpr int ncols = ncols1 * ncols2;
@@ -1180,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1180
  constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
1181
  constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
1182
 
 
 
 
 
 
 
1183
  static_assert(DKQ % tile_B::J == 0, "bad DKQ");
1184
  static_assert(DV % tile_A::J == 0, "bad DV");
1185
  static_assert(ncols % cols_per_warp == 0, "bad ncols");
1186
 
1187
- const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
1188
- const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
1189
- const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
1190
- const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
1191
- const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
1192
 
1193
  const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
1194
 
@@ -1202,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1202
  fattn_kernel_t fattn_kernel;
1203
  if (logit_softcap == 0.0f) {
1204
  constexpr bool use_logit_softcap = false;
1205
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
1206
 
1207
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1208
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1213,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1213
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1214
  } else {
1215
  constexpr bool use_logit_softcap = true;
1216
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
1217
 
1218
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1219
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
 
33
  static constexpr int nwarps_max = 4;
34
  static constexpr bool Q_in_reg = true;
35
  static constexpr int nstages_target = 2;
36
+
37
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
38
+ return 32;
39
+ }
40
+
41
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
42
+ return 32;
43
+ }
44
+
45
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
46
+ return 32;
47
+ }
48
+
49
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
50
+ return 32;
51
+ }
52
+
53
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
54
+ return 32;
55
+ }
56
+
57
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
58
+ return 32;
59
+ }
60
  };
61
 
62
  template <>
 
65
  static constexpr int nwarps_max = 4;
66
  static constexpr bool Q_in_reg = true;
67
  static constexpr int nstages_target = 2;
68
+
69
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
70
+ return 40;
71
+ }
72
+
73
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
74
+ return 40;
75
+ }
76
+
77
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
78
+ return 40;
79
+ }
80
+
81
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
82
+ return 40;
83
+ }
84
+
85
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
86
+ return 40;
87
+ }
88
+
89
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
90
+ return 40;
91
+ }
92
  };
93
 
94
  template <>
 
97
  static constexpr int nwarps_max = 4;
98
  static constexpr bool Q_in_reg = true;
99
  static constexpr int nstages_target = 2;
100
+
101
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
102
+ return 48;
103
+ }
104
+
105
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
106
+ return 48;
107
+ }
108
+
109
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
110
+ return 48;
111
+ }
112
+
113
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
114
+ return 48;
115
+ }
116
+
117
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
118
+ return 48;
119
+ }
120
+
121
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
122
+ return 48;
123
+ }
124
  };
125
 
126
  template <>
 
129
  static constexpr int nwarps_max = 4;
130
  static constexpr bool Q_in_reg = true;
131
  static constexpr int nstages_target = 2;
132
+
133
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
134
+ return 56;
135
+ }
136
+
137
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
138
+ return 56;
139
+ }
140
+
141
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
142
+ return 56;
143
+ }
144
+
145
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
146
+ return 56;
147
+ }
148
+
149
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
150
+ return 56;
151
+ }
152
+
153
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
154
+ return 56;
155
+ }
156
  };
157
 
158
  template <>
 
161
  static constexpr int nwarps_max = 4;
162
  static constexpr bool Q_in_reg = true;
163
  static constexpr int nstages_target = 2;
164
+
165
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
166
+ return 64;
167
+ }
168
+
169
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
170
+ return 64;
171
+ }
172
+
173
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
174
+ return 64;
175
+ }
176
+
177
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
178
+ return 64;
179
+ }
180
+
181
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
182
+ return 64;
183
+ }
184
+
185
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
186
+ return 64;
187
+ }
188
  };
189
 
190
  template <>
 
193
  static constexpr int nwarps_max = 4;
194
  static constexpr bool Q_in_reg = true;
195
  static constexpr int nstages_target = 2;
196
+
197
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
198
+ return 128;
199
+ }
200
+
201
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
202
+ return 128;
203
+ }
204
+
205
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
206
+ return 128;
207
+ }
208
+
209
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
210
+ return 128;
211
+ }
212
+
213
+ static int get_nbatch_combine_host(const int cc, const int ncols) {
214
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
215
+ return ncols <= 16 ? 128 : 64;
216
+ }
217
+ return 64;
218
+ }
219
+
220
+ static constexpr __device__ int get_nbatch_combine_device(int ncols) {
221
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
222
+ return ncols <= 16 ? 128 : 64;
223
+ #else
224
+ GGML_UNUSED(ncols);
225
+ return 128;
226
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
227
+ }
228
  };
229
 
230
  template <>
 
233
  static constexpr int nwarps_max = 8;
234
  static constexpr bool Q_in_reg = false;
235
  static constexpr int nstages_target = 1;
236
+
237
+ static int get_nbatch_K2_host(const int cc, const int ncols) {
238
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
239
+ return ncols <= 16 ? 96 : 160;
240
+ }
241
+ return ncols <= 16 ? 288 : 160;
242
+ }
243
+
244
+ static constexpr __device__ int get_nbatch_K2_device(int ncols) {
245
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
246
+ return ncols <= 16 ? 96 : 160;
247
+ #else
248
+ return ncols <= 16 ? 288 : 160;
249
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
250
+ }
251
+
252
+ static int get_nbatch_V2_host(const int cc, const int ncols) {
253
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
254
+ return ncols <= 16 ? 64 : 128;
255
+ }
256
+ return ncols <= 16 ? 256 : 128;
257
+ }
258
+
259
+ static constexpr __device__ int get_nbatch_V2_device(int ncols) {
260
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
261
+ return ncols <= 16 ? 64 : 128;
262
+ #else
263
+ return ncols <= 16 ? 256 : 128;
264
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
265
+ }
266
+
267
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
268
+ return 128;
269
+ }
270
+
271
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
272
+ return 128;
273
+ }
274
  };
275
 
276
  // ------------------------------------------------------------------------------------------------------------------
 
289
 
290
  const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
291
 
292
+ auto load = [&] __device__ (auto n) {
293
  const int stride_k = WARP_SIZE >> n;
294
  const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
295
  const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
 
392
  }
393
  }
394
 
395
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
396
  static __device__ __forceinline__ void flash_attn_ext_f16_iter(
397
  const float2 * const __restrict__ Q_f2,
398
  const half2 * const __restrict__ K_h2,
 
430
  constexpr int cols_per_warp = ntiles * tile_B::I;
431
  constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
432
  constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
433
+ constexpr int ncols = ncols1 * ncols2;
434
+ constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
435
+ constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
436
 
437
+ constexpr int stride_tile_Q = DKQ/2 + 4;
438
+ constexpr int stride_tile_K = nbatch_K2 + 4;
439
+
440
+ static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
441
+ constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
442
 
443
  const int k_VKQ_0 = kb0 * c::nbatch_fa;
444
  tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
 
449
  tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
450
 
451
  if constexpr (nstages > 1) {
452
+ static_assert(!mla, "multi-stage loading not implemented for MLA");
453
+ static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
454
  constexpr bool use_cp_async = true;
455
  cp_async_wait_all();
456
  __syncthreads();
457
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
458
+ (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
459
  } else {
460
  constexpr bool use_cp_async = nstages == 1;
461
  if (ncols2 > 1 || mask_h2) {
 
464
  }
465
 
466
  #pragma unroll
467
+ for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
468
+ const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
469
  const int k0_diff = k0_stop - k0_start;
470
 
471
  if (nstages <= 1) {
 
712
  (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
713
  }
714
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
715
+ (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
716
  }
717
  }
718
 
719
+
720
+ // For MLA K and V have the same data.
721
+ // Therefore, iterate over V in reverse and re-use the data if possible.
722
+ static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
723
+ constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
724
  #pragma unroll
725
+ for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
726
+ const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
727
+ const int i0_diff = i0_stop - i0_start;
728
 
729
+ if (nstages <= 1 && i0_start < reusable_cutoff) {
730
  constexpr bool use_cp_async = nstages == 1;
731
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
732
  (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
 
735
  }
736
  __syncthreads();
737
  }
738
+ const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
739
 
740
  // Calculate VKQ tile:
741
  #pragma unroll
 
746
  const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
747
 
748
  tile_A A;
749
+ load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
750
  if (ntiles == 1) {
751
  mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
752
  } else {
 
777
  #endif // NEW_MMA_AVAILABLE
778
  }
779
 
780
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
781
  static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
782
  const float2 * const __restrict__ Q_f2,
783
  const half2 * const __restrict__ K_h2,
 
813
  constexpr int cols_per_warp = ntiles * tile_B::I;
814
  constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
815
  constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
816
+ constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
817
+ constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
818
 
819
  static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
820
 
821
+ constexpr int stride_tile_Q = DKQ/2 + 4;
822
+ constexpr int stride_tile_K = nbatch_K2 + 4;
 
823
 
824
+ static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
825
+ constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
826
  constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
827
 
828
  extern __shared__ half2 tile_Q[];
 
910
 
911
  // Preload mask and K data for first iteration when using cp_async with multiple stages:
912
  if constexpr (nstages > 1) {
913
+ static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
914
  constexpr bool use_cp_async = true;
915
  if (ncols2 > 1 || mask_h2) {
916
  flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
917
  (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
918
  }
919
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
920
+ (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
921
  }
922
 
923
  // Iterate over ne11 == previous tokens:
924
  for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
925
  constexpr bool last_iter = false;
926
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
927
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
928
  ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
929
  }
930
  { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
931
  constexpr bool last_iter = true;
932
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
933
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
934
  ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
935
  }
 
958
  // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
959
  // So also write VKQ accumulators to shared memory in column-major format if np == 1.
960
 
961
+ constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
962
  constexpr int tile_stride = nbatch_combine + 4;
963
  static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
964
 
 
1196
  #endif // NEW_MMA_AVAILABLE
1197
  }
1198
 
1199
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
1200
  __launch_bounds__(nwarps*WARP_SIZE, 1)
1201
  static __global__ void flash_attn_ext_f16(
1202
  const char * __restrict__ Q,
 
1241
  NO_DEVICE_CODE;
1242
  return;
1243
  }
1244
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1245
+ if (ncols1*ncols2 > 32) {
1246
+ NO_DEVICE_CODE;
1247
+ return;
1248
+ }
1249
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1250
+
1251
+ static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
1252
 
1253
  typedef fattn_mma_f16_config<DKQ, DV> c;
1254
 
 
1259
  const int stride_Q1 = nb01 / sizeof(float2);
1260
  const int stride_Q2 = nb02 / sizeof(float2);
1261
  const int stride_K = nb11 / sizeof(half2);
 
1262
  const int stride_mask = nb31 / sizeof(half2);
1263
 
1264
+ const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
1265
+
1266
  const int iter_k = ne11 / FATTN_KQ_STRIDE;
1267
  const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
1268
 
 
1285
 
1286
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1287
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
 
1288
  const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1289
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1290
 
1291
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1292
+
1293
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1294
 
1295
  const int kb0_start_kernel = kb0_start * kb_niter;
 
1298
  constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
1299
  if (kb0_start == 0) {
1300
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
1301
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1302
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1303
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1304
  } else {
1305
  constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
1306
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1307
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1308
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1309
  }
 
1324
 
1325
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1326
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
 
1327
  const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1328
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1329
 
1330
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1331
+
1332
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1333
 
1334
  const int kb0_start_kernel = kb0_start * kb_niter;
 
1336
 
1337
  constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
1338
  constexpr bool needs_fixup = false;
1339
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1340
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1341
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1342
  #else
 
1362
 
1363
  typedef fattn_mma_f16_config<DKQ, DV> c;
1364
 
 
 
 
 
1365
  const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
1366
 
1367
  constexpr int ncols = ncols1 * ncols2;
 
1371
  constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
1372
  constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
1373
 
1374
+ constexpr bool mla = DKQ == 576;
1375
+
1376
+ const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
1377
+ const int nbatch_V2 = c::get_nbatch_V2_host (cc, ncols);
1378
+ const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
1379
+
1380
  static_assert(DKQ % tile_B::J == 0, "bad DKQ");
1381
  static_assert(DV % tile_A::J == 0, "bad DV");
1382
  static_assert(ncols % cols_per_warp == 0, "bad ncols");
1383
 
1384
+ const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
1385
+ const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
1386
+ const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
1387
+ const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
1388
+ const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
1389
 
1390
  const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
1391
 
 
1399
  fattn_kernel_t fattn_kernel;
1400
  if (logit_softcap == 0.0f) {
1401
  constexpr bool use_logit_softcap = false;
1402
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
1403
 
1404
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1405
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
 
1410
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1411
  } else {
1412
  constexpr bool use_logit_softcap = true;
1413
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
1414
 
1415
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1416
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
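The central device-side change in fattn-mma-f16.cuh is the MLA V reuse in flash_attn_ext_f16_iter: because the V tile is a suffix of the K tile already resident in shared memory, the V loop runs in reverse and skips the global-memory load for every batch at or above reusable_cutoff. A host-side sketch that evaluates the same arithmetic for the <576, 512> config, assuming nbatch_K2 = 160 and nbatch_V2 = 128 (the non-Turing values for more than 16 columns):

#include <cstdio>

int main() {
    constexpr int DKQ = 576, DV = 512;              // MLA head sizes
    constexpr int nbatch_K2 = 160, nbatch_V2 = 128; // assumed non-Turing, ncols > 16 values

    // Same formula as in flash_attn_ext_f16_iter: everything at or above the cutoff is
    // already resident in the K tile (K column k holds V column k - (DKQ - DV)).
    constexpr int reusable_cutoff = (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV);

    for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
        const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
        std::printf("V columns [%3d, %3d): %s\n", i0_start, i0_stop,
                    i0_start < reusable_cutoff ? "loaded from global memory"
                                               : "reused from the K tile in shared memory");
    }
    return 0;
}

With these values the cutoff is 256, so the upper half of V is reused and only the lower half is loaded.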
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -10,6 +10,7 @@
10
 
11
  template <int DKQ, int DV, int ncols2>
12
  static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
13
  const ggml_tensor * Q = dst->src[0];
14
 
15
  if constexpr (ncols2 <= 8) {
@@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
24
  return;
25
  }
26
 
27
- if (Q->ne[1] <= 32/ncols2) {
28
  ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
29
  return;
30
  }
 
10
 
11
  template <int DKQ, int DV, int ncols2>
12
  static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
13
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
14
  const ggml_tensor * Q = dst->src[0];
15
 
16
  if constexpr (ncols2 <= 8) {
 
25
  return;
26
  }
27
 
28
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
29
  ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
30
  return;
31
  }
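The dispatcher tweak above pairs with the new NO_DEVICE_CODE guard in flash_attn_ext_f16: on Turing, kernel variants with more than 32 total Q columns are compiled out, so the host must always pick the 32/ncols2 tile there. A rough sketch of that selection rule; the 64/ncols2 fallback and the Turing compute-capability constant (750) are assumptions taken from the surrounding code, not from this hunk:

#include <cstdio>

static bool is_turing(const int cc) {
    return cc == 750;                  // assumption: GGML_CUDA_CC_TURING == 750 (CC 7.5)
}

// Mirrors the updated switch_ncols1 logic: Turing never gets a tile wider than 32 columns,
// otherwise the 32-column tile is only used when the batch of queries fits into it.
static int pick_ncols1(const int cc, const int ncols2, const int n_queries) {
    const int small_tile = 32/ncols2;
    if (is_turing(cc) || n_queries <= small_tile) {
        return small_tile;
    }
    return 64/ncols2;                  // assumed larger fallback from the surrounding file
}

int main() {
    std::printf("Turing, ncols2 = 8, 128 queries -> ncols1 = %d\n", pick_ncols1(750, 8, 128));
    std::printf("Ampere, ncols2 = 8, 128 queries -> ncols1 = %d\n", pick_ncols1(860, 8, 128));
    return 0;
}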
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3222,7 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3222
  #endif // FLASH_ATTN_AVAILABLE
3223
  if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
3224
  const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
3225
- if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
3226
  return false;
3227
  }
3228
  const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
 
3222
  #endif // FLASH_ATTN_AVAILABLE
3223
  if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
3224
  const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
3225
+ if (!new_mma_available(cc)) {
3226
  return false;
3227
  }
3228
  const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
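Finally, the supports_op change is what actually turns on Turing support: the branch for mismatched K/V head sizes (the DeepSeek MLA shape) previously required an Ampere-class GPU on top of new_mma_available, and now needs only the new MMA kernels. A minimal before/after sketch, not the real ggml function:

// Before: the MLA branch required both the new MMA kernels and an Ampere-class GPU.
static bool mla_branch_supported_before(const bool new_mma, const int cc, const int cc_ampere) {
    return new_mma && cc >= cc_ampere;
}

// After: any GPU with the new MMA kernels (Turing or newer) qualifies.
static bool mla_branch_supported_after(const bool new_mma) {
    return new_mma;
}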