JohannesGaessler commited on
Commit
507d30c
·
1 Parent(s): 2eca371

CUDA: FA support for Deepseek (Ampere or newer) (llama/13306)

Browse files

* CUDA: FA support for Deepseek (Ampere or newer)

* do loop unrolling via C++ template

Files changed (32) hide show
  1. ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  2. ggml/src/ggml-cuda/common.cuh +19 -0
  3. ggml/src/ggml-cuda/cp-async.cuh +11 -0
  4. ggml/src/ggml-cuda/fattn-common.cuh +13 -13
  5. ggml/src/ggml-cuda/fattn-mma-f16.cuh +564 -344
  6. ggml/src/ggml-cuda/fattn-tile-f16.cu +2 -2
  7. ggml/src/ggml-cuda/fattn-tile-f32.cu +2 -2
  8. ggml/src/ggml-cuda/fattn-vec-f16.cuh +1 -1
  9. ggml/src/ggml-cuda/fattn-vec-f32.cuh +1 -1
  10. ggml/src/ggml-cuda/fattn-wmma-f16.cu +1 -1
  11. ggml/src/ggml-cuda/fattn.cu +71 -45
  12. ggml/src/ggml-cuda/ggml-cuda.cu +6 -6
  13. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  14. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +6 -6
  15. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +6 -6
  16. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +6 -6
  17. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +6 -6
  18. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  19. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +6 -6
  20. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +6 -6
  21. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +6 -6
  22. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +6 -6
  23. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  24. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +6 -6
  25. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +6 -6
  26. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +6 -6
  27. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +6 -6
  28. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +6 -6
  29. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +6 -6
  30. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +6 -6
  31. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +6 -6
  32. ggml/src/ggml-cuda/template-instances/generate_cu_files.py +12 -9
ggml/src/ggml-cuda/CMakeLists.txt CHANGED
@@ -118,7 +118,7 @@ if (CUDAToolkit_FOUND)
118
 
119
  set(CUDA_CXX_FLAGS "")
120
 
121
- set(CUDA_FLAGS -use_fast_math)
122
 
123
  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
124
  # Options are:
 
118
 
119
  set(CUDA_CXX_FLAGS "")
120
 
121
+ set(CUDA_FLAGS -use_fast_math -extended-lambda)
122
 
123
  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
124
  # Options are:
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -296,6 +296,25 @@ static __device__ void no_device_code(
296
  #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
297
  #endif // __CUDA_ARCH__
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  template<int width = WARP_SIZE>
300
  static __device__ __forceinline__ int warp_reduce_sum(int x) {
301
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
296
  #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
297
  #endif // __CUDA_ARCH__
298
 
299
+ // The compiler is always able to unroll loops if they contain continue expressions.
300
+ // In such cases loop unrolling can still be achieved via recursion:
301
+ template <int n>
302
+ struct ggml_cuda_unroll {
303
+ template <typename Func, typename... Args>
304
+ __device__ void operator()(const Func & f, Args... args) const {
305
+ f(n - 1, args...);
306
+ ggml_cuda_unroll<n - 1>{}(f, args...);
307
+ }
308
+ };
309
+
310
+ template <>
311
+ struct ggml_cuda_unroll<1> {
312
+ template <typename Func, typename... Args>
313
+ __device__ void operator()(const Func & f, Args... args) const {
314
+ f(0, args...);
315
+ }
316
+ };
317
+
318
  template<int width = WARP_SIZE>
319
  static __device__ __forceinline__ int warp_reduce_sum(int x) {
320
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
ggml/src/ggml-cuda/cp-async.cuh CHANGED
@@ -2,6 +2,17 @@
2
 
3
  #include "common.cuh"
4
 
 
 
 
 
 
 
 
 
 
 
 
5
  // Copies data from global to shared memory, cg == cache global.
6
  // Both the src and dst pointers must be aligned to 16 bit.
7
  // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
 
2
 
3
  #include "common.cuh"
4
 
5
+
6
+ static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
7
+ #ifdef CP_ASYNC_AVAILABLE
8
+ return __cvta_generic_to_shared(generic_ptr);
9
+ #else
10
+ GGML_UNUSED(generic_ptr);
11
+ NO_DEVICE_CODE;
12
+ return 0;
13
+ #endif // CP_ASYNC_AVAILABLE
14
+ }
15
+
16
  // Copies data from global to shared memory, cg == cache global.
17
  // Both the src and dst pointers must be aligned to 16 bit.
18
  // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
516
  nullptr;
517
  }
518
 
519
- template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
520
  __launch_bounds__(D, 1)
521
  static __global__ void flash_attn_stream_k_fixup(
522
  float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
@@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) {
665
  fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
666
  GGML_ABORT("fatal error");
667
  } else {
668
- fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
669
  fprintf(stderr, "Only f16 is supported.\n");
670
  GGML_ABORT("fatal error");
671
  }
672
  }
673
 
674
- template <int D, int ncols1, int ncols2, int KQ_stride>
675
  void launch_fattn(
676
  ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
677
  const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
@@ -691,7 +691,7 @@ void launch_fattn(
691
 
692
  GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
693
  GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
694
- "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
695
 
696
  GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
697
 
@@ -754,10 +754,13 @@ void launch_fattn(
754
  const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
755
 
756
  const dim3 block_dim(warp_size, nwarps, 1);
 
 
 
757
  dim3 blocks_num;
758
  if (stream_k) {
759
  // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
760
- const int max_blocks = 2*nsm;
761
  const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
762
  const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
763
 
@@ -769,14 +772,11 @@ void launch_fattn(
769
  blocks_num.y = 1;
770
  blocks_num.z = 1;
771
 
772
- dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
773
  } else {
774
  GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
775
  const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
776
 
777
- int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
778
- CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
779
-
780
  // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
781
  parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
782
 
@@ -853,19 +853,19 @@ void launch_fattn(
853
 
854
  if (stream_k) {
855
  if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
856
- const dim3 block_dim_combine(D, 1, 1);
857
  const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
858
 
859
- flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
860
  <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
861
  ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
862
  }
863
  } else if (parallel_blocks > 1) {
864
- const dim3 block_dim_combine(D, 1, 1);
865
  const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
866
  const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
867
 
868
- flash_attn_combine_results<D>
869
  <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
870
  (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
871
  }
 
516
  nullptr;
517
  }
518
 
519
+ template<int D, int ncols1, int ncols2> // D == head size
520
  __launch_bounds__(D, 1)
521
  static __global__ void flash_attn_stream_k_fixup(
522
  float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
 
665
  fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
666
  GGML_ABORT("fatal error");
667
  } else {
668
+ fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
669
  fprintf(stderr, "Only f16 is supported.\n");
670
  GGML_ABORT("fatal error");
671
  }
672
  }
673
 
674
+ template <int DV, int ncols1, int ncols2>
675
  void launch_fattn(
676
  ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
677
  const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 
691
 
692
  GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
693
  GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
694
+ "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
695
 
696
  GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
697
 
 
754
  const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
755
 
756
  const dim3 block_dim(warp_size, nwarps, 1);
757
+ int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
758
+ CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
759
+
760
  dim3 blocks_num;
761
  if (stream_k) {
762
  // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
763
+ const int max_blocks = max_blocks_per_sm*nsm;
764
  const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
765
  const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
766
 
 
772
  blocks_num.y = 1;
773
  blocks_num.z = 1;
774
 
775
+ dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
776
  } else {
777
  GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
778
  const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
779
 
 
 
 
780
  // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
781
  parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
782
 
 
853
 
854
  if (stream_k) {
855
  if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
856
+ const dim3 block_dim_combine(DV, 1, 1);
857
  const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
858
 
859
+ flash_attn_stream_k_fixup<DV, ncols1, ncols2>
860
  <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
861
  ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
862
  }
863
  } else if (parallel_blocks > 1) {
864
+ const dim3 block_dim_combine(DV, 1, 1);
865
  const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
866
  const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
867
 
868
+ flash_attn_combine_results<DV>
869
  <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
870
  (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
871
  }
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -13,104 +13,217 @@ typedef tile<16, 16, float> tile_C_KQ_16;
13
  typedef tile<16, 4, half2> tile_C_VKQ;
14
  typedef tile<16, 8, half2> tile_C_VKQ_16;
15
 
16
- template<int D, int nwarps, int KQ_per_iter>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
18
- const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
19
- constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
20
 
21
- // If cp.async is available, load up to the highest power of 2 in D asynchronously:
22
- #ifdef CP_ASYNC_AVAILABLE
23
- static_assert(D >= 64 && D < 512, "bad D");
24
- constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);
 
 
 
25
 
26
- const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
 
 
 
 
 
 
 
 
 
 
27
 
28
- constexpr int preload = 64;
29
- constexpr int h2_per_chunk = 16/sizeof(half2);
30
- constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
31
- constexpr int stride_i = WARP_SIZE / chunks_per_row;
32
  #pragma unroll
33
- for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
34
- const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
35
- const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
36
 
37
- cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
38
- }
39
- #else
40
- constexpr int k0_sync_start = 0;
41
- #endif // CP_ASYNC_AVAILABLE
42
- static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
43
 
44
- // If D is not a power of 2, the rest is loaded synchronously.
45
- // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
46
- static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds");
47
  #pragma unroll
48
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
49
- const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
50
- const int k0_stop = D/2 - (D/2) % (1*stride_k);
51
- const int stride_i = WARP_SIZE / stride_k;
52
 
53
- if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
54
- continue;
55
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  #pragma unroll
58
- for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
59
- const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
 
 
 
60
 
61
  #pragma unroll
62
- for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
63
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
64
 
65
- tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
 
66
  }
67
- }
 
68
  }
69
  }
70
 
71
- template<int ncols1, int nwarps, int KQ_per_iter>
72
  static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
73
  const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
74
- static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter");
75
- #ifdef CP_ASYNC_AVAILABLE
76
- constexpr int preload = KQ_per_iter * sizeof(half);
77
- constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter;
78
- constexpr int stride_j = nwarps * cols_per_warp;
 
79
 
80
- const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask);
81
 
82
  #pragma unroll
83
- for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
84
- const int j = j0 + threadIdx.y*cols_per_warp +
85
- (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8));
86
 
87
- if (j0 + stride_j > ncols1 && j >= ncols1) {
88
- break;
89
- }
90
 
91
- const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8));
92
 
93
- cp_async_cg_16<preload>(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
 
 
94
  }
95
- #else
96
- constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter;
97
  constexpr int stride_j = nwarps * cols_per_warp;
98
  #pragma unroll
99
  for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
100
- const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2));
101
 
102
  if (j0 + stride_j > ncols1 && j >= ncols1) {
103
  break;
104
  }
105
 
106
- const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2);
107
 
108
- tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i];
109
  }
110
- #endif // CP_ASYNC_AVAILABLE
111
  }
112
 
113
- template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
114
  static __device__ __forceinline__ void flash_attn_ext_f16_iter(
115
  const float2 * const __restrict__ Q_f2,
116
  const half2 * const __restrict__ K_h2,
@@ -123,9 +236,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
123
  const float logit_softcap,
124
  const int ne01,
125
  const int ne02,
126
- const int stride_KV,
 
127
  const int stride_mask,
128
  const int jt,
 
129
  half2 * const __restrict__ tile_K,
130
  half2 * const __restrict__ tile_V,
131
  half2 * const __restrict__ tile_mask,
@@ -135,59 +250,107 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
135
  float * const __restrict__ KQ_rowsum,
136
  const int kb0) {
137
  #ifdef NEW_MMA_AVAILABLE
 
 
 
 
 
 
 
 
138
  constexpr int cols_per_warp = ntiles * tile_B::I;
139
  constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
140
  constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
141
- constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
142
 
143
- const int k_VKQ_0 = kb0 * KQ_per_iter;
144
- tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles];
 
 
 
 
145
 
146
  // Use wide variants of tiles if ntiles >= 2.
147
  tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
148
  tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
149
  tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
150
 
151
- #ifdef CP_ASYNC_AVAILABLE
152
- cp_async_wait_all();
153
- __syncthreads();
154
- flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
155
- #else
156
- if (ncols2 > 1 || mask_h2) {
157
- flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
 
 
 
 
 
158
  }
159
- flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
160
- __syncthreads();
161
- #endif // CP_ASYNC_AVAILABLE
162
 
163
- // Calculate tile of KQ:
164
  #pragma unroll
165
- for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) {
166
- const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  #pragma unroll
168
- for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
169
- tile_A K_A;
170
- load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
171
- if (ntiles == 1) {
172
- mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
173
- } else {
174
  #pragma unroll
175
- for (int t = 0; t < ntiles/2; ++t) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  // Wide version of KQ_C is column-major => swap A and B.
177
- mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
178
  }
179
  }
180
  }
181
- }
182
 
183
- #ifndef CP_ASYNC_AVAILABLE
184
- __syncthreads(); // Only needed if tile_K == tile_V.
185
- #endif // CP_ASYNC_AVAILABLE
 
186
 
187
  if (use_logit_softcap) {
188
- static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
189
  #pragma unroll
190
- for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) {
191
  #pragma unroll
192
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
193
  KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
@@ -205,7 +368,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
205
  if (ntiles == 1) {
206
  if (ncols2 > 1 || mask_h2) {
207
  #pragma unroll
208
- for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) {
209
  const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
210
  #pragma unroll
211
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
@@ -213,16 +376,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
213
  const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
214
 
215
  KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
216
- __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]);
217
  }
218
  }
219
  }
220
 
221
  // Calculate softmax for each KQ column using the current max. value.
222
  // The divisor is stored in KQ_rowsum and will be applied at the end.
223
- static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
224
  #pragma unroll
225
- for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
226
  #pragma unroll
227
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
228
  KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
@@ -238,10 +401,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
238
  }
239
  }
240
 
241
- static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
242
-
243
  #pragma unroll
244
- for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
245
  #pragma unroll
246
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
247
  KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
@@ -252,7 +414,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
252
  } else { // ntiles > 1
253
  if (ncols2 > 1 || mask_h2) {
254
  #pragma unroll
255
- for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) {
256
  const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
257
  #pragma unroll
258
  for (int t = 0; t < ntiles/2; ++t) {
@@ -261,7 +423,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
261
  const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
262
  const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
263
 
264
- const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]);
265
  const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
266
  KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
267
  KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
@@ -272,9 +434,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
272
 
273
  // Calculate softmax for each KQ column using the current max. value.
274
  // The divisor is stored in KQ_rowsum and will be applied at the end.
275
- static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
276
  #pragma unroll
277
- for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
278
  #pragma unroll
279
  for (int t = 0; t < ntiles/2; ++t) {
280
  #pragma unroll
@@ -294,9 +456,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
294
  }
295
  }
296
 
297
- static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size");
298
  #pragma unroll
299
- for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
300
  #pragma unroll
301
  for (int t = 0; t < ntiles/2; ++t) {
302
  #pragma unroll
@@ -325,7 +487,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
325
  if (ntiles == 1) {
326
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
327
  #pragma unroll
328
- for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
329
  #pragma unroll
330
  for (int l = 0; l < tile_C_VKQ::ne; ++l) {
331
  VKQ_C[i].x[l] *= KQ_max_scale_h2;
@@ -336,7 +498,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
336
  for (int col = 0; col < cols_per_thread; ++col) {
337
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
338
  #pragma unroll
339
- for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
340
  #pragma unroll
341
  for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
342
  VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
@@ -347,16 +509,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
347
  }
348
 
349
  // Convert KQ C tiles into B tiles for VKQ calculation:
350
- tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles];
351
  tile_B_16 * B_16 = (tile_B_16 *) B;
352
- static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size");
353
  if (ntiles == 1) {
354
  #pragma unroll
355
- for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) {
356
  B[k] = get_transposed(get_half2(KQ_C[k]));
357
  }
358
  } else {
359
- for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) {
360
  #pragma unroll
361
  for (int t = 0; t < ntiles/2; ++t) {
362
  B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
@@ -364,52 +526,67 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
364
  }
365
  }
366
 
367
- #ifdef CP_ASYNC_AVAILABLE
368
- // Preload K tile for next iteration:
369
- cp_async_wait_all();
370
- __syncthreads();
371
- if (!last_iter) {
372
- if (ncols2 > 1 || mask_h2) {
373
- flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask);
 
 
 
 
 
374
  }
375
- flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV);
376
  }
377
- #else
378
- flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
379
- __syncthreads();
380
- #endif // CP_ASYNC_AVAILABLE
381
 
382
- // Calculate VKQ tile:
383
  #pragma unroll
384
- for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
385
- static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  #pragma unroll
387
- for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) {
388
- const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
389
 
390
- tile_A A;
391
- load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
392
- if (ntiles == 1) {
393
- mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
394
- } else {
395
  #pragma unroll
396
- for (int t = 0; t < ntiles/2; ++t) {
397
- // Wide version of VKQ_C is column-major => swap A and B.
398
- mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
 
399
  }
400
  }
401
  }
402
- }
403
-
404
- #ifndef CP_ASYNC_AVAILABLE
405
- __syncthreads(); // Only needed if tile_K == tile_V.
406
- #endif // CP_ASYNC_AVAILABLE
407
 
 
 
 
 
408
  #else
409
  GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
410
  GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
411
  GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
412
- GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV);
413
  GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
414
  GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
415
  GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
@@ -419,7 +596,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
419
  #endif // NEW_MMA_AVAILABLE
420
  }
421
 
422
- template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
423
  static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
424
  const float2 * const __restrict__ Q_f2,
425
  const half2 * const __restrict__ K_h2,
@@ -434,7 +611,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
434
  const int ne02,
435
  const int stride_Q1,
436
  const int stride_Q2,
437
- const int stride_KV,
 
438
  const int stride_mask,
439
  const int jt,
440
  const int kb0_start,
@@ -442,6 +620,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
442
  #ifdef NEW_MMA_AVAILABLE
443
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
444
 
 
 
 
 
 
 
 
 
445
  constexpr int ncols = ncols1 * ncols2;
446
  constexpr int cols_per_warp = ntiles * tile_B::I;
447
  constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
@@ -449,22 +635,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
449
 
450
  static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
451
 
452
- static_assert(D % nwarps == 0, "bad D");
453
- static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter");
 
454
 
455
- constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
456
 
457
- // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements:
458
- extern __shared__ half2 tile_K[];
459
- #ifdef CP_ASYNC_AVAILABLE
460
- half2 * tile_V = tile_K + KQ_per_iter*D2_padded;
461
- #else
462
- half2 * tile_V = tile_K;
463
- #endif // CP_ASYNC_AVAILABLE
464
- half2 * tile_mask = tile_V + KQ_per_iter*D2_padded;
465
 
466
- tile_B Q_B[D/(2*tile_B::J) * ntiles];
467
- tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles];
468
 
469
  tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
470
  tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
@@ -476,13 +659,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
476
  KQ_max[col] = -FLT_MAX/2.0f;
477
  }
478
 
479
- // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
 
480
  // The loading is done with decreasing granularity for D for better memory bandwidth.
481
  const half2 scale_h2 = make_half2(scale, scale);
482
  #pragma unroll
483
  for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
484
- const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
485
- const int k0_stop = D/2 - (D/2) % (1*stride_k);
486
  const int stride_jc = WARP_SIZE / stride_k;
487
 
488
  if (k0_start == k0_stop) {
@@ -506,14 +690,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
506
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
507
 
508
  const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
509
- tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
510
  }
511
  } else {
512
  #pragma unroll
513
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
514
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
515
 
516
- tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f);
517
  }
518
  }
519
  }
@@ -521,18 +705,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
521
 
522
  __syncthreads();
523
 
524
- {
525
  const int j0 = (threadIdx.y / np) * cols_per_warp;
526
 
527
  #pragma unroll
528
- for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
529
  if (ntiles == 1) {
530
- load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
531
  } else {
532
  #pragma unroll
533
  for (int t = 0; t < ntiles/2; ++t) {
534
  load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
535
- tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded);
536
  }
537
  }
538
  }
@@ -540,35 +724,37 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
540
 
541
  __syncthreads();
542
 
543
- // Preload mask and K data for first iteration when using cp_async:
544
- #ifdef CP_ASYNC_AVAILABLE
545
- if (ncols2 > 1 || mask_h2) {
546
- flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask);
 
 
 
 
 
 
547
  }
548
- flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV);
549
- #endif // CP_ASYNC_AVAILABLE
550
 
551
  // Iterate over ne11 == previous tokens:
552
  for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
553
  constexpr bool last_iter = false;
554
- flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
555
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
556
- ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
557
  }
558
  { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
559
  constexpr bool last_iter = true;
560
- flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
561
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
562
- ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
563
  }
564
 
565
- // With cp_async there is no __syncthreads at the end of the iter,
566
  // there can be a race condition on shared memory access for combining/writing back results.
567
- #ifdef CP_ASYNC_AVAILABLE
568
- if (nwarps*cols_per_warp > KQ_per_iter) {
569
  __syncthreads();
570
  }
571
- #endif // CP_ASYNC_AVAILABLE
572
 
573
  // Finally, sum up partial KQ rowsums.
574
  // The partial sums are spread across 8/4 threads each, does not need full reduce.
@@ -584,38 +770,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
584
  }
585
  }
586
 
587
- // Write VKQ accumulators to shared memory in column-major format.
588
- // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
589
- // Also for np > 1 the combination is done via these values in shared memory.
590
- if (ntiles == 1) {
591
- const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
592
- #pragma unroll
593
- for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
594
- const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
595
 
596
- #pragma unroll
597
- for (int l = 0; l < tile_B::ne; ++l) {
598
- const int k = k0 + tile_B::get_j(l);
599
-
600
- tile_K[jc_cwd*D2_padded + k] = B.x[l];
601
- }
602
- }
603
- } else {
604
- #pragma unroll
605
- for (int t = 0; t < ntiles/2; ++t) {
606
- const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
607
- #pragma unroll
608
- for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) {
609
- #pragma unroll
610
- for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
611
- const int j = j0 + tile_C_VKQ_16::get_i(l);
612
- const int k = k0 + tile_C_VKQ_16::get_j(l);
613
-
614
- tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
615
- }
616
- }
617
- }
618
- }
619
 
620
  if constexpr (ntiles == 1) {
621
  const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
@@ -624,7 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
624
 
625
  if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
626
  // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
627
- ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
628
  }
629
 
630
  __syncthreads();
@@ -649,7 +810,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
649
 
650
  if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
651
  // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
652
- ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
653
  }
654
 
655
  __syncthreads();
@@ -676,11 +837,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
676
  constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
677
 
678
  const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
679
- float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4;
680
  float2 meta[nmeta];
681
  #pragma unroll
682
  for (int imeta = 0; imeta < nmeta; ++imeta) {
683
- meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2];
684
  }
685
 
686
  float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
@@ -690,10 +851,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
690
  }
691
  #pragma unroll
692
  for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
693
- if (offset >= WARP_SIZE) {
694
- continue;
695
  }
696
- KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
697
  }
698
 
699
  float KQ_cms[nmeta]; // KQ combine max scale per warp.
@@ -709,10 +869,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
709
  }
710
  #pragma unroll
711
  for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
712
- if (offset >= WARP_SIZE) {
713
- continue;
714
  }
715
- KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
716
  }
717
 
718
  // Write back combined meta data:
@@ -720,7 +879,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
720
  for (int imeta = 0; imeta < nmeta; ++imeta) {
721
  if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
722
  // Combined KQ max scale + rowsum.
723
- meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs);
724
  }
725
  }
726
 
@@ -736,88 +895,118 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
736
  }
737
  }
738
 
739
- if (np > 1) {
740
- __syncthreads();
741
- }
742
-
743
- if (np == 1 || threadIdx.y % np == 0) {
744
- // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
745
- // The values after that are for the partial results of the individual blocks.
746
- float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
747
 
748
  #pragma unroll
749
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
750
- const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
751
- const int k0_stop = D/2 - (D/2) % (1*stride_k);
752
- const int stride_jc = WARP_SIZE / stride_k;
753
 
754
- if (k0_start == k0_stop) {
755
- continue;
756
  }
757
-
 
 
 
 
 
758
  #pragma unroll
759
- for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
760
- const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
761
 
762
- if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
763
- break;
764
  }
 
 
765
 
766
- const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
 
 
 
 
 
767
 
768
- const int j_dst = jc_dst / ncols2;
769
- const int c_dst = jc_dst % ncols2;
 
 
 
770
 
771
- if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
772
  continue;
773
  }
774
 
775
- const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2;
776
  #pragma unroll
777
- for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
778
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
779
 
780
- float2 dstk_val = make_float2(0.0f, 0.0f);
781
- #pragma unroll
782
- for (int ip = 0; ip < np; ++ip) {
783
- const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0];
784
- const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]);
785
- dstk_val.x += dstk_val_add.x*KQ_crs;
786
- dstk_val.y += dstk_val_add.y*KQ_crs;
787
  }
788
 
789
- if (!needs_fixup && !is_fixup) {
790
- const float KQ_rowsum_j = meta_j[1];
791
- dstk_val.x /= KQ_rowsum_j;
792
- dstk_val.y /= KQ_rowsum_j;
 
 
 
793
  }
794
 
795
- if (is_fixup) {
796
- dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val;
797
- } else {
798
- dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
799
  }
800
  }
801
  }
802
  }
803
- }
804
-
805
- if (np > 1) {
806
- __syncthreads();
807
  }
808
  #else
809
  GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
810
  GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
811
  GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
812
  GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
813
- GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask);
814
  GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
815
  NO_DEVICE_CODE;
816
  #endif // NEW_MMA_AVAILABLE
817
  }
818
 
819
- template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap>
820
- __launch_bounds__(nwarps*WARP_SIZE, 2)
821
  static __global__ void flash_attn_ext_f16(
822
  const char * __restrict__ Q,
823
  const char * __restrict__ K,
@@ -857,24 +1046,27 @@ static __global__ void flash_attn_ext_f16(
857
  #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
858
 
859
  // Skip unused kernel variants for faster compilation:
860
- if (use_logit_softcap && !(D == 128 || D == 256)) {
861
  NO_DEVICE_CODE;
862
  return;
863
  }
864
 
865
- static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter");
 
 
866
 
867
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
868
 
869
  const int stride_Q1 = nb01 / sizeof(float2);
870
  const int stride_Q2 = nb02 / sizeof(float2);
871
- const int stride_KV = nb11 / sizeof(half2);
 
872
  const int stride_mask = nb31 / sizeof(half2);
873
 
874
  const int iter_k = ne11 / FATTN_KQ_STRIDE;
875
  const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
876
 
877
- constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice.
878
 
879
  // kbc == k block continuous, current index in continuous ijk space.
880
  int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
@@ -893,9 +1085,9 @@ static __global__ void flash_attn_ext_f16(
893
 
894
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
895
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
896
- const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
897
  const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
898
- float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
899
 
900
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
901
 
@@ -905,14 +1097,14 @@ static __global__ void flash_attn_ext_f16(
905
  constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
906
  if (kb0_start == 0) {
907
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
908
- flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
909
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
910
- ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
911
  } else {
912
  constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
913
- flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
914
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
915
- ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
916
  }
917
 
918
  kbc += iter_k;
@@ -931,9 +1123,9 @@ static __global__ void flash_attn_ext_f16(
931
 
932
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
933
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
934
- const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
935
  const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
936
- float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
937
 
938
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
939
 
@@ -942,9 +1134,9 @@ static __global__ void flash_attn_ext_f16(
942
 
943
  constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
944
  constexpr bool needs_fixup = false;
945
- flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
946
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
947
- ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
948
  #else
949
  GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
950
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
@@ -960,28 +1152,42 @@ static __global__ void flash_attn_ext_f16(
960
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
961
  }
962
 
963
- template <int D, int ncols1, int ncols2>
964
  void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
 
 
 
 
 
 
 
 
 
 
 
965
  constexpr int ncols = ncols1 * ncols2;
966
- constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32;
967
- constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4;
968
- constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4);
969
  constexpr int cols_per_warp = ntiles * tile_B::I;
 
 
 
970
 
971
- static_assert(D % tile_B::J == 0, "bad D");
 
972
  static_assert(ncols % cols_per_warp == 0, "bad ncols");
973
 
974
- const ggml_tensor * KQV = dst;
975
- const int id = ggml_cuda_get_device();
976
- const int cc = ggml_cuda_info().devices[id].cc;
 
 
977
 
978
- const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter;
979
 
980
- const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half);
981
- const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half);
982
- const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half);
983
-
984
- const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine);
985
 
986
  float logit_softcap;
987
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -989,59 +1195,73 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
989
  fattn_kernel_t fattn_kernel;
990
  if (logit_softcap == 0.0f) {
991
  constexpr bool use_logit_softcap = false;
992
- fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
 
 
 
 
 
 
 
 
993
  } else {
994
  constexpr bool use_logit_softcap = true;
995
- fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
 
 
 
 
 
 
 
 
996
  }
997
 
998
- launch_fattn<D, ncols1, ncols2, KQ_per_iter>
999
  (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
1000
  }
1001
 
1002
 
1003
- #define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \
1004
- template void ggml_cuda_flash_attn_ext_mma_f16_case \
1005
- <D, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
1006
-
1007
- #define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \
1008
- extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \
1009
- extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \
1010
- extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
1011
- extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
1012
-
1013
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8)
1014
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8)
1015
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8)
1016
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8)
1017
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8)
1018
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8)
1019
-
1020
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16)
1021
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16)
1022
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16)
1023
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16)
1024
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16)
1025
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16)
1026
-
1027
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32)
1028
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32)
1029
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32)
1030
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32)
1031
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32)
1032
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32)
1033
-
1034
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64)
1035
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64)
1036
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64)
1037
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64)
1038
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64)
1039
- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64)
1040
-
1041
- // Kernels with ncols == 128 are only 4% faster due to register pressure.
1042
- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
1043
- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
1044
- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
1045
- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
1046
- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
1047
- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
 
13
  typedef tile<16, 4, half2> tile_C_VKQ;
14
  typedef tile<16, 8, half2> tile_C_VKQ_16;
15
 
16
+ // Config options for specific head sizes.
17
+ // Should not affect results, only speed/register pressure/shared memory use.
18
+ //
19
+ // nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
20
+ // nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory).
21
+ // Q_in_reg: whether the Q values should be kept permanently in registers.
22
+ // nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading.
23
+ // nbatch_K2: number of K half2 values in direction of DKQ to load in parallel.
24
+ // nbatch_V2: number of V half2 values in direction of DV to load in parallel.
25
+ // nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
26
+
27
+ template <int DKQ, int DV>
28
+ struct fattn_mma_f16_config;
29
+
30
+ template <>
31
+ struct fattn_mma_f16_config< 64, 64> {
32
+ static constexpr int nbatch_fa = 64;
33
+ static constexpr int nwarps_max = 4;
34
+ static constexpr bool Q_in_reg = true;
35
+ static constexpr int nstages_target = 2;
36
+ static constexpr int nbatch_K2 = 32;
37
+ static constexpr int nbatch_V2 = 32;
38
+ static constexpr int nbatch_combine = 32;
39
+ };
40
+
41
+ template <>
42
+ struct fattn_mma_f16_config< 80, 80> {
43
+ static constexpr int nbatch_fa = 64;
44
+ static constexpr int nwarps_max = 4;
45
+ static constexpr bool Q_in_reg = true;
46
+ static constexpr int nstages_target = 2;
47
+ static constexpr int nbatch_K2 = 40;
48
+ static constexpr int nbatch_V2 = 40;
49
+ static constexpr int nbatch_combine = 40;
50
+ };
51
+
52
+ template <>
53
+ struct fattn_mma_f16_config< 96, 96> {
54
+ static constexpr int nbatch_fa = 64;
55
+ static constexpr int nwarps_max = 4;
56
+ static constexpr bool Q_in_reg = true;
57
+ static constexpr int nstages_target = 2;
58
+ static constexpr int nbatch_K2 = 48;
59
+ static constexpr int nbatch_V2 = 48;
60
+ static constexpr int nbatch_combine = 48;
61
+ };
62
+
63
+ template <>
64
+ struct fattn_mma_f16_config<112, 112> {
65
+ static constexpr int nbatch_fa = 64;
66
+ static constexpr int nwarps_max = 4;
67
+ static constexpr bool Q_in_reg = true;
68
+ static constexpr int nstages_target = 2;
69
+ static constexpr int nbatch_K2 = 56;
70
+ static constexpr int nbatch_V2 = 56;
71
+ static constexpr int nbatch_combine = 56;
72
+ };
73
+
74
+ template <>
75
+ struct fattn_mma_f16_config<128, 128> {
76
+ static constexpr int nbatch_fa = 64;
77
+ static constexpr int nwarps_max = 4;
78
+ static constexpr bool Q_in_reg = true;
79
+ static constexpr int nstages_target = 2;
80
+ static constexpr int nbatch_K2 = 64;
81
+ static constexpr int nbatch_V2 = 64;
82
+ static constexpr int nbatch_combine = 64;
83
+ };
84
+
85
+ template <>
86
+ struct fattn_mma_f16_config<256, 256> {
87
+ static constexpr int nbatch_fa = 32;
88
+ static constexpr int nwarps_max = 4;
89
+ static constexpr bool Q_in_reg = true;
90
+ static constexpr int nstages_target = 2;
91
+ static constexpr int nbatch_K2 = 128;
92
+ static constexpr int nbatch_V2 = 128;
93
+ static constexpr int nbatch_combine = 128;
94
+ };
95
+
96
+ template <>
97
+ struct fattn_mma_f16_config<576, 512> {
98
+ static constexpr int nbatch_fa = 32;
99
+ static constexpr int nwarps_max = 8;
100
+ static constexpr bool Q_in_reg = false;
101
+ static constexpr int nstages_target = 1;
102
+ static constexpr int nbatch_K2 = 160;
103
+ static constexpr int nbatch_V2 = 128;
104
+ static constexpr int nbatch_combine = 128;
105
+ };
106
+
107
+ // ------------------------------------------------------------------------------------------------------------------
108
+
109
+ template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
110
  static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
111
+ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
 
112
 
113
+ // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
114
+ // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
115
+
116
+ if (use_cp_async) {
117
+ constexpr int preload = 64;
118
+ constexpr int h2_per_chunk = 16/sizeof(half2);
119
+ const int chunks_per_row = D2 / h2_per_chunk;
120
 
121
+ const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
122
+
123
+ auto load = [&] __device__ (const int n) {
124
+ const int stride_k = WARP_SIZE >> n;
125
+ const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
126
+ const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
127
+ const int stride_i = WARP_SIZE / stride_k;
128
+
129
+ if (k0_start == k0_stop) {
130
+ return;
131
+ }
132
 
 
 
 
 
133
  #pragma unroll
134
+ for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
135
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
136
 
137
+ if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
138
+ break;
139
+ }
 
 
 
140
 
 
 
 
141
  #pragma unroll
142
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
143
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
 
144
 
145
+ cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
146
+ }
147
+ }
148
+ };
149
+ ggml_cuda_unroll<5>{}(load);
150
+ } else {
151
+ static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
152
+ auto load = [&] __device__ (const int n) {
153
+ const int stride_k = WARP_SIZE >> n;
154
+ const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
155
+ const int k0_stop = D2 - D2 % (1*stride_k);
156
+ const int stride_i = WARP_SIZE / stride_k;
157
+
158
+ if (k0_start == k0_stop) {
159
+ return;
160
+ }
161
 
162
  #pragma unroll
163
+ for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
164
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
165
+
166
+ if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
167
+ break;
168
+ }
169
 
170
  #pragma unroll
171
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
172
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
173
 
174
+ tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
175
+ }
176
  }
177
+ };
178
+ ggml_cuda_unroll<3>{}(load);
179
  }
180
  }
181
 
182
+ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
183
  static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
184
  const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
185
+ static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
186
+
187
+ if (use_cp_async) {
188
+ constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
189
+ constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
190
+ constexpr int stride_j = nwarps * cols_per_warp;
191
 
192
+ const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
193
 
194
  #pragma unroll
195
+ for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
196
+ const int j = j0 + threadIdx.y*cols_per_warp +
197
+ (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp));
198
 
199
+ if (j0 + stride_j > ncols1 && j >= ncols1) {
200
+ break;
201
+ }
202
 
203
+ const int i = 4 * (threadIdx.x % (nbatch_fa/8));
204
 
205
+ cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
206
+ }
207
+ return;
208
  }
209
+
210
+ constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
211
  constexpr int stride_j = nwarps * cols_per_warp;
212
  #pragma unroll
213
  for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
214
+ const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp));
215
 
216
  if (j0 + stride_j > ncols1 && j >= ncols1) {
217
  break;
218
  }
219
 
220
+ const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp);
221
 
222
+ tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i];
223
  }
 
224
  }
225
 
226
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
227
  static __device__ __forceinline__ void flash_attn_ext_f16_iter(
228
  const float2 * const __restrict__ Q_f2,
229
  const half2 * const __restrict__ K_h2,
 
236
  const float logit_softcap,
237
  const int ne01,
238
  const int ne02,
239
+ const int stride_K,
240
+ const int stride_V,
241
  const int stride_mask,
242
  const int jt,
243
+ half2 * const __restrict__ tile_Q,
244
  half2 * const __restrict__ tile_K,
245
  half2 * const __restrict__ tile_V,
246
  half2 * const __restrict__ tile_mask,
 
250
  float * const __restrict__ KQ_rowsum,
251
  const int kb0) {
252
  #ifdef NEW_MMA_AVAILABLE
253
+ typedef fattn_mma_f16_config<DKQ, DV> c;
254
+
255
+ #ifdef CP_ASYNC_AVAILABLE
256
+ constexpr int nstages = c::nstages_target;
257
+ #else
258
+ constexpr int nstages = 0;
259
+ #endif // CP_ASYNC_AVAILABLE
260
+
261
  constexpr int cols_per_warp = ntiles * tile_B::I;
262
  constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
263
  constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
 
264
 
265
+ constexpr int stride_tile_Q = DKQ/2 + 4;
266
+ constexpr int stride_tile_K = c::nbatch_K2 + 4;
267
+ constexpr int stride_tile_V = c::nbatch_V2 + 4;
268
+
269
+ const int k_VKQ_0 = kb0 * c::nbatch_fa;
270
+ tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
271
 
272
  // Use wide variants of tiles if ntiles >= 2.
273
  tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
274
  tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
275
  tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
276
 
277
+ if constexpr (nstages > 1) {
278
+ static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
279
+ constexpr bool use_cp_async = true;
280
+ cp_async_wait_all();
281
+ __syncthreads();
282
+ flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
283
+ (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
284
+ } else {
285
+ constexpr bool use_cp_async = nstages == 1;
286
+ if (ncols2 > 1 || mask_h2) {
287
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
288
+ }
289
  }
 
 
 
290
 
 
291
  #pragma unroll
292
+ for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
293
+ const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
294
+ const int k0_diff = k0_stop - k0_start;
295
+
296
+ if (nstages <= 1) {
297
+ constexpr bool use_cp_async = nstages == 1;
298
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
299
+ (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
300
+ if (use_cp_async) {
301
+ cp_async_wait_all();
302
+ }
303
+ __syncthreads();
304
+ }
305
+
306
+ // Calculate tile of KQ:
307
+ if constexpr (c::Q_in_reg) {
308
  #pragma unroll
309
+ for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
310
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
 
 
 
 
311
  #pragma unroll
312
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
313
+ tile_A K_A;
314
+ load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
315
+ if (ntiles == 1) {
316
+ mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
317
+ } else {
318
+ #pragma unroll
319
+ for (int t = 0; t < ntiles/2; ++t) {
320
+ // Wide version of KQ_C is column-major => swap A and B.
321
+ mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
322
+ }
323
+ }
324
+ }
325
+ }
326
+ } else {
327
+ static_assert(ntiles == 2, "ntiles != 2 not implemented");
328
+ #pragma unroll
329
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
330
+ load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
331
+
332
+ #pragma unroll
333
+ for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
334
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
335
+
336
+ tile_A K_A;
337
+ load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
338
+
339
  // Wide version of KQ_C is column-major => swap A and B.
340
+ mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A);
341
  }
342
  }
343
  }
 
344
 
345
+ if (nstages <= 1) {
346
+ __syncthreads(); // Only needed if tile_K == tile_V.
347
+ }
348
+ }
349
 
350
  if (use_logit_softcap) {
351
+ static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
352
  #pragma unroll
353
+ for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
354
  #pragma unroll
355
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
356
  KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
 
368
  if (ntiles == 1) {
369
  if (ncols2 > 1 || mask_h2) {
370
  #pragma unroll
371
+ for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
372
  const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
373
  #pragma unroll
374
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
 
376
  const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
377
 
378
  KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
379
+ __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
380
  }
381
  }
382
  }
383
 
384
  // Calculate softmax for each KQ column using the current max. value.
385
  // The divisor is stored in KQ_rowsum and will be applied at the end.
386
+ static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
387
  #pragma unroll
388
+ for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
389
  #pragma unroll
390
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
391
  KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
 
401
  }
402
  }
403
 
404
+ static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
 
405
  #pragma unroll
406
+ for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
407
  #pragma unroll
408
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
409
  KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
 
414
  } else { // ntiles > 1
415
  if (ncols2 > 1 || mask_h2) {
416
  #pragma unroll
417
+ for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) {
418
  const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
419
  #pragma unroll
420
  for (int t = 0; t < ntiles/2; ++t) {
 
423
  const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
424
  const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
425
 
426
+ const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]);
427
  const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
428
  KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
429
  KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
 
434
 
435
  // Calculate softmax for each KQ column using the current max. value.
436
  // The divisor is stored in KQ_rowsum and will be applied at the end.
437
+ static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
438
  #pragma unroll
439
+ for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
440
  #pragma unroll
441
  for (int t = 0; t < ntiles/2; ++t) {
442
  #pragma unroll
 
456
  }
457
  }
458
 
459
+ static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size");
460
  #pragma unroll
461
+ for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
462
  #pragma unroll
463
  for (int t = 0; t < ntiles/2; ++t) {
464
  #pragma unroll
 
487
  if (ntiles == 1) {
488
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
489
  #pragma unroll
490
+ for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
491
  #pragma unroll
492
  for (int l = 0; l < tile_C_VKQ::ne; ++l) {
493
  VKQ_C[i].x[l] *= KQ_max_scale_h2;
 
498
  for (int col = 0; col < cols_per_thread; ++col) {
499
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
500
  #pragma unroll
501
+ for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
502
  #pragma unroll
503
  for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
504
  VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
 
509
  }
510
 
511
  // Convert KQ C tiles into B tiles for VKQ calculation:
512
+ tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
513
  tile_B_16 * B_16 = (tile_B_16 *) B;
514
+ static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
515
  if (ntiles == 1) {
516
  #pragma unroll
517
+ for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
518
  B[k] = get_transposed(get_half2(KQ_C[k]));
519
  }
520
  } else {
521
+ for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) {
522
  #pragma unroll
523
  for (int t = 0; t < ntiles/2; ++t) {
524
  B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
 
526
  }
527
  }
528
 
529
+ if (nstages > 1) {
530
+ // Preload K tile for next iteration:
531
+ constexpr bool use_cp_async = true;
532
+ cp_async_wait_all();
533
+ __syncthreads();
534
+ if (!last_iter) {
535
+ if (ncols2 > 1 || mask_h2) {
536
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
537
+ (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
538
+ }
539
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
540
+ (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
541
  }
 
542
  }
 
 
 
 
543
 
 
544
  #pragma unroll
545
+ for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
546
+ const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
547
+ const int i0_diff = i0_stop - i0_start;
548
+
549
+ if (nstages == 1) {
550
+ constexpr bool use_cp_async = nstages == 1;
551
+ flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
552
+ (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
553
+ if (use_cp_async) {
554
+ cp_async_wait_all();
555
+ }
556
+ __syncthreads();
557
+ }
558
+
559
+ // Calculate VKQ tile:
560
+ #pragma unroll
561
+ for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) {
562
+ static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size");
563
  #pragma unroll
564
+ for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) {
565
+ const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
566
 
567
+ tile_A A;
568
+ load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
569
+ if (ntiles == 1) {
570
+ mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
571
+ } else {
572
  #pragma unroll
573
+ for (int t = 0; t < ntiles/2; ++t) {
574
+ // Wide version of VKQ_C is column-major => swap A and B.
575
+ mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
576
+ }
577
  }
578
  }
579
  }
 
 
 
 
 
580
 
581
+ if (nstages <= 1) {
582
+ __syncthreads(); // Only needed if tile_K == tile_V.
583
+ }
584
+ }
585
  #else
586
  GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
587
  GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
588
  GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
589
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
590
  GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
591
  GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
592
  GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
 
596
  #endif // NEW_MMA_AVAILABLE
597
  }
598
 
599
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
600
  static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
601
  const float2 * const __restrict__ Q_f2,
602
  const half2 * const __restrict__ K_h2,
 
611
  const int ne02,
612
  const int stride_Q1,
613
  const int stride_Q2,
614
+ const int stride_K,
615
+ const int stride_V,
616
  const int stride_mask,
617
  const int jt,
618
  const int kb0_start,
 
620
  #ifdef NEW_MMA_AVAILABLE
621
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
622
 
623
+ typedef fattn_mma_f16_config<DKQ, DV> c;
624
+
625
+ #ifdef CP_ASYNC_AVAILABLE
626
+ constexpr int nstages = c::nstages_target;
627
+ #else
628
+ constexpr int nstages = 0;
629
+ #endif // CP_ASYNC_AVAILABLE
630
+
631
  constexpr int ncols = ncols1 * ncols2;
632
  constexpr int cols_per_warp = ntiles * tile_B::I;
633
  constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
 
635
 
636
  static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
637
 
638
+ constexpr int stride_tile_Q = DKQ/2 + 4;
639
+ constexpr int stride_tile_K = c::nbatch_K2 + 4;
640
+ constexpr int stride_tile_V = c::nbatch_V2 + 4;
641
 
642
+ constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
643
 
644
+ extern __shared__ half2 tile_Q[];
645
+ half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
646
+ half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K;
647
+ half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max;
 
 
 
 
648
 
649
+ tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles];
650
+ tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles];
651
 
652
  tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
653
  tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
 
659
  KQ_max[col] = -FLT_MAX/2.0f;
660
  }
661
 
662
+ // Load Q data into tile_Q, either temporarily or permanently.
663
+ // Q in registers is faster, but register pressure is the biggest bottleneck.
664
  // The loading is done with decreasing granularity for D for better memory bandwidth.
665
  const half2 scale_h2 = make_half2(scale, scale);
666
  #pragma unroll
667
  for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
668
+ const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
669
+ const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
670
  const int stride_jc = WARP_SIZE / stride_k;
671
 
672
  if (k0_start == k0_stop) {
 
690
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
691
 
692
  const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
693
+ tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
694
  }
695
  } else {
696
  #pragma unroll
697
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
698
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
699
 
700
+ tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
701
  }
702
  }
703
  }
 
705
 
706
  __syncthreads();
707
 
708
+ if (c::Q_in_reg) {
709
  const int j0 = (threadIdx.y / np) * cols_per_warp;
710
 
711
  #pragma unroll
712
+ for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
713
  if (ntiles == 1) {
714
+ load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
715
  } else {
716
  #pragma unroll
717
  for (int t = 0; t < ntiles/2; ++t) {
718
  load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
719
+ tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
720
  }
721
  }
722
  }
 
724
 
725
  __syncthreads();
726
 
727
+ // Preload mask and K data for first iteration when using cp_async with multiple stages:
728
+ if constexpr (nstages > 1) {
729
+ static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
730
+ constexpr bool use_cp_async = true;
731
+ if (ncols2 > 1 || mask_h2) {
732
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
733
+ (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
734
+ }
735
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
736
+ (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
737
  }
 
 
738
 
739
  // Iterate over ne11 == previous tokens:
740
  for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
741
  constexpr bool last_iter = false;
742
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
743
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
744
+ ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
745
  }
746
  { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
747
  constexpr bool last_iter = true;
748
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
749
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
750
+ ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
751
  }
752
 
753
+ // With multi-stage loading there is no __syncthreads at the end of the iter,
754
  // there can be a race condition on shared memory access for combining/writing back results.
755
+ if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
 
756
  __syncthreads();
757
  }
 
758
 
759
  // Finally, sum up partial KQ rowsums.
760
  // The partial sums are spread across 8/4 threads each, does not need full reduce.
 
770
  }
771
  }
772
 
773
+ // Combine VKQ accumulator values if np > 1.
774
+ // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
775
+ // So also write VKQ accumulators to shared memory in column-major format if np == 1.
 
 
 
 
 
776
 
777
+ constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
778
+ constexpr int tile_stride = nbatch_combine + 4;
779
+ static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780
 
781
  if constexpr (ntiles == 1) {
782
  const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
 
785
 
786
  if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
787
  // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
788
+ ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
789
  }
790
 
791
  __syncthreads();
 
810
 
811
  if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
812
  // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
813
+ ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
814
  }
815
 
816
  __syncthreads();
 
837
  constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
838
 
839
  const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
840
+ float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
841
  float2 meta[nmeta];
842
  #pragma unroll
843
  for (int imeta = 0; imeta < nmeta; ++imeta) {
844
+ meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
845
  }
846
 
847
  float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
 
851
  }
852
  #pragma unroll
853
  for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
854
+ if (offset < WARP_SIZE) {
855
+ KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
856
  }
 
857
  }
858
 
859
  float KQ_cms[nmeta]; // KQ combine max scale per warp.
 
869
  }
870
  #pragma unroll
871
  for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
872
+ if (offset < WARP_SIZE) {
873
+ KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
874
  }
 
875
  }
876
 
877
  // Write back combined meta data:
 
879
  for (int imeta = 0; imeta < nmeta; ++imeta) {
880
  if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
881
  // Combined KQ max scale + rowsum.
882
+ meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
883
  }
884
  }
885
 
 
895
  }
896
  }
897
 
898
+ #pragma unroll
899
+ for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
900
+ if (ntiles == 1) {
901
+ const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
902
+ #pragma unroll
903
+ for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
904
+ const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
 
905
 
906
  #pragma unroll
907
+ for (int l = 0; l < tile_B::ne; ++l) {
908
+ const int k = k0 + tile_B::get_j(l);
 
 
909
 
910
+ tile_Q[jc_cwd*tile_stride + k] = B.x[l];
911
+ }
912
  }
913
+ } else {
914
+ #pragma unroll
915
+ for (int t = 0; t < ntiles/2; ++t) {
916
+ const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
917
+ #pragma unroll
918
+ for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) {
919
  #pragma unroll
920
+ for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
921
+ const int j = j0 + tile_C_VKQ_16::get_i(l);
922
+ const int k = k0 + tile_C_VKQ_16::get_j(l);
923
 
924
+ tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
925
+ }
926
  }
927
+ }
928
+ }
929
 
930
+ __syncthreads();
931
+
932
+ if (np == 1 || threadIdx.y % np == 0) {
933
+ // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
934
+ // The values after that are for the partial results of the individual blocks.
935
+ float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
936
 
937
+ #pragma unroll
938
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
939
+ const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
940
+ const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
941
+ const int stride_jc = WARP_SIZE / stride_k;
942
 
943
+ if (k0_start == k0_stop) {
944
  continue;
945
  }
946
 
 
947
  #pragma unroll
948
+ for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
949
+ const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
950
 
951
+ if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
952
+ break;
 
 
 
 
 
953
  }
954
 
955
+ const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
956
+
957
+ const int j_dst = jc_dst / ncols2;
958
+ const int c_dst = jc_dst % ncols2;
959
+
960
+ if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
961
+ continue;
962
  }
963
 
964
+ const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
965
+ #pragma unroll
966
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
967
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
968
+
969
+ float2 dstk_val = make_float2(0.0f, 0.0f);
970
+ #pragma unroll
971
+ for (int ip = 0; ip < np; ++ip) {
972
+ const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0];
973
+ const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]);
974
+ dstk_val.x += dstk_val_add.x*KQ_crs;
975
+ dstk_val.y += dstk_val_add.y*KQ_crs;
976
+ }
977
+
978
+ if (!needs_fixup && !is_fixup) {
979
+ const float KQ_rowsum_j = meta_j[1];
980
+ dstk_val.x /= KQ_rowsum_j;
981
+ dstk_val.y /= KQ_rowsum_j;
982
+ }
983
+
984
+ if (is_fixup) {
985
+ dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val;
986
+ } else {
987
+ dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val;
988
+ }
989
  }
990
  }
991
  }
992
  }
993
+ if (np > 1) {
994
+ __syncthreads();
995
+ }
 
996
  }
997
  #else
998
  GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
999
  GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
1000
  GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
1001
  GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
1002
+ GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
1003
  GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
1004
  NO_DEVICE_CODE;
1005
  #endif // NEW_MMA_AVAILABLE
1006
  }
1007
 
1008
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
1009
+ __launch_bounds__(nwarps*WARP_SIZE, 1)
1010
  static __global__ void flash_attn_ext_f16(
1011
  const char * __restrict__ Q,
1012
  const char * __restrict__ K,
 
1046
  #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
1047
 
1048
  // Skip unused kernel variants for faster compilation:
1049
+ if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
1050
  NO_DEVICE_CODE;
1051
  return;
1052
  }
1053
 
1054
+ typedef fattn_mma_f16_config<DKQ, DV> c;
1055
+
1056
+ static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config<DKQ, DV>::nbatch_fa == 0, "bad nbatch_fa");
1057
 
1058
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
1059
 
1060
  const int stride_Q1 = nb01 / sizeof(float2);
1061
  const int stride_Q2 = nb02 / sizeof(float2);
1062
+ const int stride_K = nb11 / sizeof(half2);
1063
+ const int stride_V = nb21 / sizeof(half2);
1064
  const int stride_mask = nb31 / sizeof(half2);
1065
 
1066
  const int iter_k = ne11 / FATTN_KQ_STRIDE;
1067
  const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
1068
 
1069
+ constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
1070
 
1071
  // kbc == k block continuous, current index in continuous ijk space.
1072
  int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
 
1085
 
1086
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1087
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1088
+ const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1089
  const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1090
+ float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1091
 
1092
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1093
 
 
1097
  constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
1098
  if (kb0_start == 0) {
1099
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
1100
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
1101
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1102
+ ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1103
  } else {
1104
  constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
1105
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
1106
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1107
+ ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1108
  }
1109
 
1110
  kbc += iter_k;
 
1123
 
1124
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1125
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1126
+ const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
1127
  const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1128
+ float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1129
 
1130
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1131
 
 
1134
 
1135
  constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
1136
  constexpr bool needs_fixup = false;
1137
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
1138
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1139
+ ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1140
  #else
1141
  GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
1142
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
 
1152
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
1153
  }
1154
 
1155
+ template <int DKQ, int DV, int ncols1, int ncols2>
1156
  void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1157
+ const ggml_tensor * KQV = dst;
1158
+ const int id = ggml_cuda_get_device();
1159
+ const int cc = ggml_cuda_info().devices[id].cc;
1160
+
1161
+ typedef fattn_mma_f16_config<DKQ, DV> c;
1162
+
1163
+ constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
1164
+ constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
1165
+ constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
1166
+
1167
+ const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
1168
+
1169
  constexpr int ncols = ncols1 * ncols2;
1170
+ constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
 
 
1171
  constexpr int cols_per_warp = ntiles * tile_B::I;
1172
+ constexpr int nwarps_max_x = ncols / cols_per_warp;
1173
+ constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
1174
+ constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
1175
 
1176
+ static_assert(DKQ % tile_B::J == 0, "bad DKQ");
1177
+ static_assert(DV % tile_A::J == 0, "bad DV");
1178
  static_assert(ncols % cols_per_warp == 0, "bad ncols");
1179
 
1180
+ const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
1181
+ const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
1182
+ const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
1183
+ const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
1184
+ const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
1185
 
1186
+ const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
1187
 
1188
+ const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ?
1189
+ std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
1190
+ nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
 
 
1191
 
1192
  float logit_softcap;
1193
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
1195
  fattn_kernel_t fattn_kernel;
1196
  if (logit_softcap == 0.0f) {
1197
  constexpr bool use_logit_softcap = false;
1198
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
1199
+
1200
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1201
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1202
+ if (!shared_memory_limit_raised[id]) {
1203
+ CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1204
+ shared_memory_limit_raised[id] = true;
1205
+ }
1206
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1207
  } else {
1208
  constexpr bool use_logit_softcap = true;
1209
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
1210
+
1211
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1212
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1213
+ if (!shared_memory_limit_raised[id]) {
1214
+ CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1215
+ shared_memory_limit_raised[id] = true;
1216
+ }
1217
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1218
  }
1219
 
1220
+ launch_fattn<DV, ncols1, ncols2>
1221
  (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
1222
  }
1223
 
1224
 
1225
+ #define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \
1226
+ template void ggml_cuda_flash_attn_ext_mma_f16_case \
1227
+ <DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
1228
+
1229
+ #define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \
1230
+ extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \
1231
+ extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \
1232
+ extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \
1233
+ extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \
1234
+ extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \
1235
+
1236
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8)
1237
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8)
1238
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8)
1239
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8)
1240
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8)
1241
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8)
1242
+
1243
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16)
1244
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16)
1245
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16)
1246
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16)
1247
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16)
1248
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16)
1249
+
1250
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32)
1251
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32)
1252
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32)
1253
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32)
1254
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32)
1255
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32)
1256
+
1257
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64)
1258
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64)
1259
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64)
1260
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
1261
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
1262
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
1263
+
1264
+ // The number of viable configurations for Deepseek is very limited:
1265
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
1266
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
1267
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
 
 
ggml/src/ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
307
  constexpr int nwarps = 8;
308
  constexpr size_t nbytes_shared = 0;
309
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
310
- launch_fattn<D, cols_per_block, 1, -1>
311
  (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
312
  } break;
313
  case 128: {
@@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
315
  constexpr int nwarps = 8;
316
  constexpr size_t nbytes_shared = 0;
317
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
318
- launch_fattn<D, cols_per_block, 1, -1>
319
  (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
320
  } break;
321
  default: {
 
307
  constexpr int nwarps = 8;
308
  constexpr size_t nbytes_shared = 0;
309
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
310
+ launch_fattn<D, cols_per_block, 1>
311
  (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
312
  } break;
313
  case 128: {
 
315
  constexpr int nwarps = 8;
316
  constexpr size_t nbytes_shared = 0;
317
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
318
+ launch_fattn<D, cols_per_block, 1>
319
  (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
320
  } break;
321
  default: {
ggml/src/ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
318
  constexpr int nwarps = 8;
319
  constexpr size_t nbytes_shared = 0;
320
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
321
- launch_fattn<D, cols_per_block, 1, -1>
322
  (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
323
  } break;
324
  case 128: {
@@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
326
  constexpr int nwarps = 8;
327
  constexpr size_t nbytes_shared = 0;
328
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
329
- launch_fattn<D, cols_per_block, 1, -1>
330
  (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
331
  } break;
332
  default: {
 
318
  constexpr int nwarps = 8;
319
  constexpr size_t nbytes_shared = 0;
320
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
321
+ launch_fattn<D, cols_per_block, 1>
322
  (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
323
  } break;
324
  case 128: {
 
326
  constexpr int nwarps = 8;
327
  constexpr size_t nbytes_shared = 0;
328
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
329
+ launch_fattn<D, cols_per_block, 1>
330
  (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
331
  } break;
332
  default: {
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -315,7 +315,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
315
  constexpr bool need_f16_K = D != 128;
316
  constexpr bool need_f16_V = D != 128 && D != 64;
317
  constexpr size_t nbytes_shared = 0;
318
- launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
319
  }
320
 
321
  template <int D, ggml_type type_K, ggml_type type_V>
 
315
  constexpr bool need_f16_K = D != 128;
316
  constexpr bool need_f16_V = D != 128 && D != 64;
317
  constexpr size_t nbytes_shared = 0;
318
+ launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
319
  }
320
 
321
  template <int D, ggml_type type_K, ggml_type type_V>
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
310
  constexpr bool need_f16_K = D != 128;
311
  constexpr bool need_f16_V = D != 128 && D != 64;
312
  constexpr size_t nbytes_shared = 0;
313
- launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
314
  }
315
 
316
  template <int D, ggml_type type_K, ggml_type type_V>
 
310
  constexpr bool need_f16_K = D != 128;
311
  constexpr bool need_f16_V = D != 128 && D != 64;
312
  constexpr size_t nbytes_shared = 0;
313
+ launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
314
  }
315
 
316
  template <int D, ggml_type type_K, ggml_type type_V>
ggml/src/ggml-cuda/fattn-wmma-f16.cu CHANGED
@@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
490
  fattn_kernel = flash_attn_ext_f16<
491
  D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
492
  }
493
- launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
494
  }
495
 
496
  void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
490
  fattn_kernel = flash_attn_ext_f16<
491
  D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
492
  }
493
+ launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
494
  }
495
 
496
  void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -8,58 +8,32 @@
8
  #include "fattn-wmma-f16.cuh"
9
  #include "fattn.cuh"
10
 
11
- template <int D, int ncols2>
12
  static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
13
  const ggml_tensor * Q = dst->src[0];
14
 
15
- if (Q->ne[1] <= 8/ncols2) {
16
- ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
17
- return;
 
 
18
  }
19
 
20
  if (Q->ne[1] <= 16/ncols2) {
21
- ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
22
  return;
23
  }
24
 
25
  if (Q->ne[1] <= 32/ncols2) {
26
- ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
27
  return;
28
  }
29
 
30
- ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
31
- }
32
-
33
- template <int ncols2>
34
- static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
35
- const ggml_tensor * Q = dst->src[0];
36
-
37
- switch (Q->ne[0]) {
38
- case 64:
39
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
40
- break;
41
- case 80:
42
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
43
- break;
44
- case 96:
45
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
46
- break;
47
- case 112:
48
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
49
- break;
50
- case 128:
51
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
52
- break;
53
- case 256:
54
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
55
- break;
56
- default:
57
- GGML_ABORT("fatal error");
58
- break;
59
- }
60
  }
61
 
62
- static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
63
  const ggml_tensor * KQV = dst;
64
  const ggml_tensor * Q = dst->src[0];
65
  const ggml_tensor * K = dst->src[1];
@@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
68
  float max_bias = 0.0f;
69
  memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
70
 
71
- const float use_gqa_opt = mask && max_bias == 0.0f;
72
 
73
  GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
74
  const int gqa_ratio = Q->ne[2] / K->ne[2];
75
 
76
  if (use_gqa_opt && gqa_ratio % 8 == 0) {
77
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
78
  return;
79
  }
80
 
81
- if (use_gqa_opt && gqa_ratio == 4) {
82
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
83
  return;
84
  }
85
 
86
- if (use_gqa_opt && gqa_ratio == 2) {
87
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
88
  return;
89
  }
90
 
91
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  }
93
 
94
  #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
@@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
299
  const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
300
  const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
301
  const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
302
- const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
303
  if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
304
  if (prec == GGML_PREC_DEFAULT) {
305
  ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
 
8
  #include "fattn-wmma-f16.cuh"
9
  #include "fattn.cuh"
10
 
11
+ template <int DKQ, int DV, int ncols2>
12
  static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
13
  const ggml_tensor * Q = dst->src[0];
14
 
15
+ if constexpr (ncols2 <= 8) {
16
+ if (Q->ne[1] <= 8/ncols2) {
17
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
18
+ return;
19
+ }
20
  }
21
 
22
  if (Q->ne[1] <= 16/ncols2) {
23
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
24
  return;
25
  }
26
 
27
  if (Q->ne[1] <= 32/ncols2) {
28
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
29
  return;
30
  }
31
 
32
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  }
34
 
35
+ template <int DKQ, int DV>
36
+ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
37
  const ggml_tensor * KQV = dst;
38
  const ggml_tensor * Q = dst->src[0];
39
  const ggml_tensor * K = dst->src[1];
 
42
  float max_bias = 0.0f;
43
  memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
44
 
45
+ const bool use_gqa_opt = mask && max_bias == 0.0f;
46
 
47
  GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
48
  const int gqa_ratio = Q->ne[2] / K->ne[2];
49
 
50
  if (use_gqa_opt && gqa_ratio % 8 == 0) {
51
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
52
  return;
53
  }
54
 
55
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
56
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
57
  return;
58
  }
59
 
60
+ if (use_gqa_opt && gqa_ratio % 2 == 0) {
61
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
62
  return;
63
  }
64
 
65
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
66
+ }
67
+
68
+ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
69
+ const ggml_tensor * KQV = dst;
70
+ const ggml_tensor * Q = dst->src[0];
71
+ const ggml_tensor * K = dst->src[1];
72
+ const ggml_tensor * V = dst->src[2];
73
+ const ggml_tensor * mask = dst->src[3];
74
+
75
+ switch (Q->ne[0]) {
76
+ case 64:
77
+ GGML_ASSERT(V->ne[0] == 64);
78
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst);
79
+ break;
80
+ case 80:
81
+ GGML_ASSERT(V->ne[0] == 80);
82
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst);
83
+ break;
84
+ case 96:
85
+ GGML_ASSERT(V->ne[0] == 96);
86
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst);
87
+ break;
88
+ case 112:
89
+ GGML_ASSERT(V->ne[0] == 112);
90
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
91
+ break;
92
+ case 128:
93
+ GGML_ASSERT(V->ne[0] == 128);
94
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
95
+ break;
96
+ case 256:
97
+ GGML_ASSERT(V->ne[0] == 256);
98
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
99
+ break;
100
+ case 576: {
101
+ // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
102
+ GGML_ASSERT(V->ne[0] == 512);
103
+ float max_bias = 0.0f;
104
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
105
+
106
+ const bool use_gqa_opt = mask && max_bias == 0.0f;
107
+ GGML_ASSERT(use_gqa_opt);
108
+
109
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
110
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
111
+ GGML_ASSERT(gqa_ratio % 16 == 0);
112
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
113
+ } break;
114
+ default:
115
+ GGML_ABORT("fatal error");
116
+ break;
117
+ }
118
  }
119
 
120
  #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
 
325
  const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
326
  const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
327
  const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
328
+ const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
329
  if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
330
  if (prec == GGML_PREC_DEFAULT) {
331
  ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3215,16 +3215,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3215
  return false;
3216
  #endif // FLASH_ATTN_AVAILABLE
3217
  if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
3218
- // different head sizes of K and V are not supported yet
3219
- return false;
 
 
 
 
3220
  }
3221
  if (op->src[0]->ne[0] == 192) {
3222
  return false;
3223
  }
3224
- if (op->src[0]->ne[0] == 576) {
3225
- // DeepSeek MLA
3226
- return false;
3227
- }
3228
  if (op->src[0]->ne[3] != 1) {
3229
  return false;
3230
  }
 
3215
  return false;
3216
  #endif // FLASH_ATTN_AVAILABLE
3217
  if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
3218
+ const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
3219
+ if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
3220
+ return false;
3221
+ }
3222
+ const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
3223
+ return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
3224
  }
3225
  if (op->src[0]->ne[0] == 192) {
3226
  return false;
3227
  }
 
 
 
 
3228
  if (op->src[0]->ne[3] != 1) {
3229
  return false;
3230
  }
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 1, 8);
6
- DECL_FATTN_MMA_F16_CASE(80, 1, 8);
7
- DECL_FATTN_MMA_F16_CASE(96, 1, 8);
8
- DECL_FATTN_MMA_F16_CASE(112, 1, 8);
9
- DECL_FATTN_MMA_F16_CASE(128, 1, 8);
10
- DECL_FATTN_MMA_F16_CASE(256, 1, 8);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 16, 1);
6
- DECL_FATTN_MMA_F16_CASE(80, 16, 1);
7
- DECL_FATTN_MMA_F16_CASE(96, 16, 1);
8
- DECL_FATTN_MMA_F16_CASE(112, 16, 1);
9
- DECL_FATTN_MMA_F16_CASE(128, 16, 1);
10
- DECL_FATTN_MMA_F16_CASE(256, 16, 1);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 16, 2);
6
- DECL_FATTN_MMA_F16_CASE(80, 16, 2);
7
- DECL_FATTN_MMA_F16_CASE(96, 16, 2);
8
- DECL_FATTN_MMA_F16_CASE(112, 16, 2);
9
- DECL_FATTN_MMA_F16_CASE(128, 16, 2);
10
- DECL_FATTN_MMA_F16_CASE(256, 16, 2);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 16, 4);
6
- DECL_FATTN_MMA_F16_CASE(80, 16, 4);
7
- DECL_FATTN_MMA_F16_CASE(96, 16, 4);
8
- DECL_FATTN_MMA_F16_CASE(112, 16, 4);
9
- DECL_FATTN_MMA_F16_CASE(128, 16, 4);
10
- DECL_FATTN_MMA_F16_CASE(256, 16, 4);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 2, 4);
6
- DECL_FATTN_MMA_F16_CASE(80, 2, 4);
7
- DECL_FATTN_MMA_F16_CASE(96, 2, 4);
8
- DECL_FATTN_MMA_F16_CASE(112, 2, 4);
9
- DECL_FATTN_MMA_F16_CASE(128, 2, 4);
10
- DECL_FATTN_MMA_F16_CASE(256, 2, 4);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 2, 8);
6
- DECL_FATTN_MMA_F16_CASE(80, 2, 8);
7
- DECL_FATTN_MMA_F16_CASE(96, 2, 8);
8
- DECL_FATTN_MMA_F16_CASE(112, 2, 8);
9
- DECL_FATTN_MMA_F16_CASE(128, 2, 8);
10
- DECL_FATTN_MMA_F16_CASE(256, 2, 8);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 32, 1);
6
- DECL_FATTN_MMA_F16_CASE(80, 32, 1);
7
- DECL_FATTN_MMA_F16_CASE(96, 32, 1);
8
- DECL_FATTN_MMA_F16_CASE(112, 32, 1);
9
- DECL_FATTN_MMA_F16_CASE(128, 32, 1);
10
- DECL_FATTN_MMA_F16_CASE(256, 32, 1);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 32, 2);
6
- DECL_FATTN_MMA_F16_CASE(80, 32, 2);
7
- DECL_FATTN_MMA_F16_CASE(96, 32, 2);
8
- DECL_FATTN_MMA_F16_CASE(112, 32, 2);
9
- DECL_FATTN_MMA_F16_CASE(128, 32, 2);
10
- DECL_FATTN_MMA_F16_CASE(256, 32, 2);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 4, 2);
6
- DECL_FATTN_MMA_F16_CASE(80, 4, 2);
7
- DECL_FATTN_MMA_F16_CASE(96, 4, 2);
8
- DECL_FATTN_MMA_F16_CASE(112, 4, 2);
9
- DECL_FATTN_MMA_F16_CASE(128, 4, 2);
10
- DECL_FATTN_MMA_F16_CASE(256, 4, 2);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 4, 4);
6
- DECL_FATTN_MMA_F16_CASE(80, 4, 4);
7
- DECL_FATTN_MMA_F16_CASE(96, 4, 4);
8
- DECL_FATTN_MMA_F16_CASE(112, 4, 4);
9
- DECL_FATTN_MMA_F16_CASE(128, 4, 4);
10
- DECL_FATTN_MMA_F16_CASE(256, 4, 4);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 4, 8);
6
- DECL_FATTN_MMA_F16_CASE(80, 4, 8);
7
- DECL_FATTN_MMA_F16_CASE(96, 4, 8);
8
- DECL_FATTN_MMA_F16_CASE(112, 4, 8);
9
- DECL_FATTN_MMA_F16_CASE(128, 4, 8);
10
- DECL_FATTN_MMA_F16_CASE(256, 4, 8);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 64, 1);
6
- DECL_FATTN_MMA_F16_CASE(80, 64, 1);
7
- DECL_FATTN_MMA_F16_CASE(96, 64, 1);
8
- DECL_FATTN_MMA_F16_CASE(112, 64, 1);
9
- DECL_FATTN_MMA_F16_CASE(128, 64, 1);
10
- DECL_FATTN_MMA_F16_CASE(256, 64, 1);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 8, 1);
6
- DECL_FATTN_MMA_F16_CASE(80, 8, 1);
7
- DECL_FATTN_MMA_F16_CASE(96, 8, 1);
8
- DECL_FATTN_MMA_F16_CASE(112, 8, 1);
9
- DECL_FATTN_MMA_F16_CASE(128, 8, 1);
10
- DECL_FATTN_MMA_F16_CASE(256, 8, 1);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 8, 2);
6
- DECL_FATTN_MMA_F16_CASE(80, 8, 2);
7
- DECL_FATTN_MMA_F16_CASE(96, 8, 2);
8
- DECL_FATTN_MMA_F16_CASE(112, 8, 2);
9
- DECL_FATTN_MMA_F16_CASE(128, 8, 2);
10
- DECL_FATTN_MMA_F16_CASE(256, 8, 2);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 8, 4);
6
- DECL_FATTN_MMA_F16_CASE(80, 8, 4);
7
- DECL_FATTN_MMA_F16_CASE(96, 8, 4);
8
- DECL_FATTN_MMA_F16_CASE(112, 8, 4);
9
- DECL_FATTN_MMA_F16_CASE(128, 8, 4);
10
- DECL_FATTN_MMA_F16_CASE(256, 8, 4);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 8, 8);
6
- DECL_FATTN_MMA_F16_CASE(80, 8, 8);
7
- DECL_FATTN_MMA_F16_CASE(96, 8, 8);
8
- DECL_FATTN_MMA_F16_CASE(112, 8, 8);
9
- DECL_FATTN_MMA_F16_CASE(128, 8, 8);
10
- DECL_FATTN_MMA_F16_CASE(256, 8, 8);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8);
6
+ DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);
7
+ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
8
+ DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
9
+ DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
10
+ DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
ggml/src/ggml-cuda/template-instances/generate_cu_files.py CHANGED
@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
18
 
19
  """
20
 
21
- SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
22
 
23
  TYPES_MMQ = [
24
  "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,18 +57,21 @@ for vkq_size in [16, 32]:
57
  with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
58
  f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
59
 
60
- for ncols in [8, 16, 32, 64, 128]:
61
- for ncols2 in [1, 2, 4, 8]:
 
 
62
  ncols1 = ncols // ncols2
63
- if ncols == 128:
64
- continue # Too much register pressure.
65
  with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
66
  f.write(SOURCE_FATTN_MMA_START)
67
 
68
- for head_size in [64, 80, 96, 112, 128, 256]:
69
- if ncols == 128 and head_size == 256:
70
- continue # Needs too much shared memory.
71
- f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
 
 
 
72
 
73
  for type in TYPES_MMQ:
74
  with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
 
18
 
19
  """
20
 
21
+ SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
22
 
23
  TYPES_MMQ = [
24
  "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
 
57
  with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
58
  f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
59
 
60
+ for ncols in [8, 16, 32, 64]:
61
+ for ncols2 in [1, 2, 4, 8, 16]:
62
+ if ncols2 > ncols:
63
+ continue
64
  ncols1 = ncols // ncols2
 
 
65
  with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
66
  f.write(SOURCE_FATTN_MMA_START)
67
 
68
+ for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
69
+ if head_size_kq != 576 and ncols2 == 16:
70
+ continue
71
+ if head_size_kq == 576 and ncols2 != 16:
72
+ continue
73
+ head_size_v = head_size_kq if head_size_kq != 576 else 512
74
+ f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
75
 
76
  for type in TYPES_MMQ:
77
  with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: