JohannesGaessler committed
Commit 10ac92f · 1 Parent(s): 2746afd

CUDA: fix overflow in FA, tune performance (llama/14840)

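For context on the overflow being fixed: the K/V offsets in these kernels were computed in 32-bit arithmetic, and although each factor (row index, element stride) fits in an int, their product can exceed INT_MAX once the KV cache is long enough. The changes below therefore cast the row index to int64_t before multiplying. A minimal sketch of the pattern, with illustrative names rather than the actual kernel code:

#include <cuda_fp16.h>
#include <cstdint>

// Illustrative only: with row ~ 300000 and stride_KV2 ~ 8192 the 32-bit product
// (~2.5e9) wraps around before it ever becomes a pointer offset.
__global__ void gather_row(const half2 * K_h2, half2 * out, const int row, const int stride_KV2) {
    // const int  off32 = row*stride_KV2 + threadIdx.x;          // can overflow INT_MAX
    const int64_t off64 = int64_t(row)*stride_KV2 + threadIdx.x; // promote first: safe
    out[threadIdx.x] = K_h2[off64];
}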
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -23,33 +23,13 @@ typedef void (* fattn_kernel_t)(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3);
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33);
 
 typedef half (*vec_dot_KQ_f16_t)(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
@@ -892,14 +872,11 @@ void launch_fattn(
         mask ? ((const char *) mask->data) : nullptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        nb11, nb12, nb13,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
         nb21, nb22, nb23,
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
     );
     CUDA_CHECK(cudaGetLastError());
 
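Note the mixed widths in the new fattn_kernel_t signature above: the logical dimensions and low-order byte strides stay int32_t, while the dim-3 byte strides nb13, nb23 and nb33 become int64_t, since a single K/V tensor can span more than 2 GiB. A back-of-the-envelope check with illustrative (not commit-specific) numbers:

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical f16 K cache: head size 128, 32 KV heads, 512Ki-token context, 2 bytes per element.
    const int64_t nb13 = 128LL * 32 * 524288 * 2; // dim-3 (per-sequence) byte stride
    std::printf("nb13 = %lld bytes, INT32_MAX = %d\n", (long long) nb13, INT32_MAX);
    return 0; // nb13 is over 4 GB, past what an int32_t stride could represent
}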
 
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int stride_K,
         const int stride_V,
         const int stride_mask,
-        const int jt,
         half2 * const __restrict__ tile_Q,
         half2 * const __restrict__ tile_K,
         half2 * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             cp_async_wait_all();
             __syncthreads();
             flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-                (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+                (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
         } else {
             constexpr bool use_cp_async = nstages == 1;
             if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (nstages <= 1) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
             }
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+                (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
         }
     }
 
@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (nstages <= 1 && i0_start < reusable_cutoff) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-                (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+                (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
     GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
     GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
         }
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+            (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
     }
 
     // Iterate over ne11 == previous tokens:
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
     { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
ggml/src/ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -127,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
         for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
             const int k_KQ = k_KQ_0 + threadIdx.x;
 
-            KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+            KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
         }
     }
 
@@ -221,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+            KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
         }
     }
 
@@ -300,8 +280,7 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
ggml/src/ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f32(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
@@ -66,8 +46,7 @@ static __global__ void flash_attn_tile_ext_f32(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
     return;
     }
@@ -135,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
-            const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+            const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
             KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
             KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
         }
@@ -231,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32(
         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
-            KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+            const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
+            KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
+            KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
         }
     }
 
@@ -312,7 +292,6 @@ static __global__ void flash_attn_tile_ext_f32(
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -191,13 +171,16 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
+    K += blockIdx.y*D * nb11;
+    V += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
             }
 
             __syncthreads();
@@ -244,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+            half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum((float)sum);
 
             if (use_logit_softcap) {
@@ -300,14 +283,18 @@ static __global__ void flash_attn_vec_ext_f16(
             }
 
             half2 V_k;
-            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
-            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
+            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
             }
         }
 
+        K += gridDim.y*D * nb11;
+        V += gridDim.y*D * nb21;
+        maskh += gridDim.y*D;
+
         __syncthreads();
     }
 
@@ -351,8 +338,7 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f32(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
@@ -59,8 +39,7 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
     return;
     }
@@ -198,13 +177,16 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
+    K += blockIdx.y*D * nb11;
+    V += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
             }
 
             __syncthreads();
@@ -246,7 +228,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+            float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum(sum);
 
             if (use_logit_softcap) {
@@ -297,13 +279,17 @@ static __global__ void flash_attn_vec_ext_f32(
                 break;
             }
 
-            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+            const float V_ki = dequantize_1_v(V + k*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_ki*KQ[j*D + k];
             }
         }
 
+        K += gridDim.y*D * nb11;
+        V += gridDim.y*D * nb21;
+        maskh += gridDim.y*D;
+
         __syncthreads();
     }
 
@@ -348,7 +334,6 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
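The restructuring in the two vec kernels above replaces absolute indexing of K, V and the mask (offsets of the form (k_VKQ_0 + i)*nb11) with base pointers that are advanced once per tile, so only small per-tile offsets are ever added and the index arithmetic cannot overflow. A stripped-down sketch of that pattern, with simplified names and an empty tile body (not the actual kernel code):

#include <cuda_fp16.h>
#include <cstddef>

// Each block starts at its own first tile and then hops gridDim.y*D rows per iteration;
// the advanced base pointers absorb the large part of the offset.
__device__ void walk_kv_tiles(const char * K, const char * V, const half * maskh,
                              const int ne11, const int D, const size_t nb11, const size_t nb21) {
    K     += blockIdx.y*D * nb11;
    V     += blockIdx.y*D * nb21;
    maskh += blockIdx.y*D;
    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
        // ... process K[0 .. D*nb11), V[0 .. D*nb21), maskh[0 .. D) for this tile ...
        K     += gridDim.y*D * nb11;
        V     += gridDim.y*D * nb21;
        maskh += gridDim.y*D;
    }
}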
ggml/src/ggml-cuda/fattn-wmma-f16.cu CHANGED
@@ -37,33 +37,13 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -197,7 +177,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
             frag_a_K K_a;
-            wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
                 wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -344,7 +324,7 @@ static __global__ void flash_attn_ext_f16(
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
             frag_a_V v_a;
-            wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+            wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
                 wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -451,7 +431,6 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -280,22 +280,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (GGML_CUDA_CC_IS_AMD(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN)
-        if (fp16_mma_available(cc)) {
-            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-            return;
-        }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
-
-        // On AMD the tile kernels perform poorly, use the vec kernel instead:
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
+    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
+        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
 
     if (!fast_fp16_available(cc)) {
         if (Q->ne[1] <= 8 || Q->ne[0] == 256) {