Commit 10ac92f
Parent(s): 2746afd
CUDA: fix overflow in FA, tune performance (llama/14840)
- ggml/src/ggml-cuda/fattn-common.cuh   +11 -34
- ggml/src/ggml-cuda/fattn-mma-f16.cuh  +16 -39
- ggml/src/ggml-cuda/fattn-tile-f16.cu  +10 -31
- ggml/src/ggml-cuda/fattn-tile-f32.cu  +12 -33
- ggml/src/ggml-cuda/fattn-vec-f16.cuh  +19 -33
- ggml/src/ggml-cuda/fattn-vec-f32.cuh  +18 -33
- ggml/src/ggml-cuda/fattn-wmma-f16.cu  +9 -30
- ggml/src/ggml-cuda/fattn.cu           +3 -13
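The overflow fix that gives the commit its title is visible in the hunks below: tile offsets into K and V that used to be computed as `int` products (for example `V_h2 + k_VKQ_0*stride_V`) are widened to `int64_t` before the multiply (`V_h2 + int64_t(k_VKQ_0)*stride_V`), and the largest byte strides (nb13, nb23, nb33) are now passed to the kernels as `int64_t`. As a standalone illustration (not code from this commit, sizes invented), the following shows how a 32-bit product wraps for a long KV cache while the widened product does not:

#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative sizes only: a token index deep into a long KV cache and a
    // per-token stride in elements.
    const int k_VKQ_0  = 300000;
    const int stride_V = 8192;

    // 32-bit product: wraps modulo 2^32 (computed via unsigned so the
    // demonstration stays well-defined), so the resulting offset is garbage.
    const int32_t wrapped = (int32_t) ((uint32_t) k_VKQ_0 * (uint32_t) stride_V);

    // The pattern used in the patched kernels: widen one operand first.
    const int64_t exact = (int64_t) k_VKQ_0 * stride_V;

    printf("32-bit product: %d\n64-bit product: %lld\n", wrapped, (long long) exact);
    return 0;
}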
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED

@@ -23,33 +23,13 @@ typedef void (* fattn_kernel_t)(
     const float m1,
     const uint32_t n_head_log2,
     const float logit_softcap,
-    const int ne00,
-    const int ne01,
-    const int ne02,
-    const int ne03,
-    const int ne10,
-    const int ne11,
-    const int ne12,
-    const int ne13,
-    const int ne31,
-    const int ne32,
-    const int ne33,
-    const int nb31,
-    const int nb32,
-    const int nb33,
-    const int nb01,
-    const int nb02,
-    const int nb03,
-    const int nb11,
-    const int nb12,
-    const int nb13,
-    const int nb21,
-    const int nb22,
-    const int nb23,
-    const int ne0,
-    const int ne1,
-    const int ne2,
-    const int ne3);
+    const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+    const int32_t nb01, const int32_t nb02, const int32_t nb03,
+    const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+    const int32_t nb11, const int32_t nb12, const int64_t nb13,
+    const int32_t nb21, const int32_t nb22, const int64_t nb23,
+    const int32_t ne31, const int32_t ne32, const int32_t ne33,
+    const int32_t nb31, const int32_t nb32, const int64_t nb33);

 typedef half (*vec_dot_KQ_f16_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);

@@ -892,14 +872,11 @@ void launch_fattn(
         mask ? ((const char *) mask->data) : nullptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        nb11, nb12, nb13,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
         nb21, nb22, nb23,
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
     );
     CUDA_CHECK(cudaGetLastError());
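In the new `fattn_kernel_t` signature most extents and strides stay `int32_t`; only the outermost byte strides of K, V and the mask (`nb13`, `nb23`, `nb33`) become `int64_t`. A quick back-of-the-envelope check, with assumed tensor sizes that are not taken from this commit, shows why those particular strides are the ones that can no longer be trusted to fit in 32 bits:

#include <cstdint>
#include <cstdio>

int main() {
    // Assumed fp16 K cache: 128-element heads, 262144 cached tokens, 32 KV heads.
    const int64_t ne10 = 128, ne11 = 262144, ne12 = 32;
    const int64_t type_size = 2; // bytes per fp16 element

    // ggml-style byte strides of a contiguous tensor: nb11 steps over tokens,
    // nb12 over heads, nb13 over sequences in the batch.
    const int64_t nb11 = ne10*type_size;  //           256 bytes
    const int64_t nb12 = nb11*ne11;       //    67 108 864 bytes
    const int64_t nb13 = nb12*ne12;       // 2 147 483 648 bytes: just past INT32_MAX

    printf("nb11 = %lld, nb12 = %lld, nb13 = %lld (INT32_MAX = %d)\n",
           (long long) nb11, (long long) nb12, (long long) nb13, INT32_MAX);
    return 0;
}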
ggml/src/ggml-cuda/fattn-mma-f16.cuh
CHANGED

@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     const int stride_K,
     const int stride_V,
     const int stride_mask,
-    const int jt,
     half2 * const __restrict__ tile_Q,
     half2 * const __restrict__ tile_K,
     half2 * const __restrict__ tile_V,

@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         cp_async_wait_all();
         __syncthreads();
         flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-            (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
     } else {
         constexpr bool use_cp_async = nstages == 1;
         if (ncols2 > 1 || mask_h2) {

@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     if (nstages <= 1) {
         constexpr bool use_cp_async = nstages == 1;
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+            (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
         if (use_cp_async) {
             cp_async_wait_all();
         }

@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
         }
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+            (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
     }
 }

@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     if (nstages <= 1 && i0_start < reusable_cutoff) {
         constexpr bool use_cp_async = nstages == 1;
         flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-            (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+            (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
         if (use_cp_async) {
             cp_async_wait_all();
         }

@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
     GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
     GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);

@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
     }
     flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-        (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+        (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
     }

     // Iterate over ne11 == previous tokens:

@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
     { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
     }

     // With multi-stage loading there is no __syncthreads at the end of the iter,

@@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16(
     const float m1,
     const uint32_t n_head_log2,
     const float logit_softcap,
-    const int ne00,
-    const int ne01,
-    const int ne02,
-    const int ne03,
-    const int ne10,
-    const int ne11,
-    const int ne12,
-    const int ne13,
-    const int ne31,
-    const int ne32,
-    const int ne33,
-    const int nb31,
-    const int nb32,
-    const int nb33,
-    const int nb01,
-    const int nb02,
-    const int nb03,
-    const int nb11,
-    const int nb12,
-    const int nb13,
-    const int nb21,
-    const int nb22,
-    const int nb23,
-    const int ne0,
-    const int ne1,
-    const int ne2,
-    const int ne3) {
+    const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+    const int32_t nb01, const int32_t nb02, const int32_t nb03,
+    const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+    const int32_t nb11, const int32_t nb12, const int64_t nb13,
+    const int32_t nb21, const int32_t nb22, const int64_t nb23,
+    const int32_t ne31, const int32_t ne32, const int32_t ne33,
+    const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)

     // Skip unused kernel variants for faster compilation:

@@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
ggml/src/ggml-cuda/fattn-tile-f16.cu
CHANGED

@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
     const float m1,
     const uint32_t n_head_log2,
     const float logit_softcap,
-    const int ne00,
-    const int ne01,
-    const int ne02,
-    const int ne03,
-    const int ne10,
-    const int ne11,
-    const int ne12,
-    const int ne13,
-    const int ne31,
-    const int ne32,
-    const int ne33,
-    const int nb31,
-    const int nb32,
-    const int nb33,
-    const int nb01,
-    const int nb02,
-    const int nb03,
-    const int nb11,
-    const int nb12,
-    const int nb13,
-    const int nb21,
-    const int nb22,
-    const int nb23,
-    const int ne0,
-    const int ne1,
-    const int ne2,
-    const int ne3) {
+    const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+    const int32_t nb01, const int32_t nb02, const int32_t nb03,
+    const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+    const int32_t nb11, const int32_t nb12, const int64_t nb13,
+    const int32_t nb21, const int32_t nb22, const int64_t nb23,
+    const int32_t ne31, const int32_t ne32, const int32_t ne33,
+    const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)

     // Skip unused kernel variants for faster compilation:

@@ -127,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                 const int k_KQ = k_KQ_0 + threadIdx.x;

-                KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+                KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
             }
         }

@@ -221,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;

-                KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+                KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
             }
         }

@@ -300,8 +280,7 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
ggml/src/ggml-cuda/fattn-tile-f32.cu
CHANGED

@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f32(
     const float m1,
     const uint32_t n_head_log2,
     const float logit_softcap,
-    const int ne00,
-    const int ne01,
-    const int ne02,
-    const int ne03,
-    const int ne10,
-    const int ne11,
-    const int ne12,
-    const int ne13,
-    const int ne31,
-    const int ne32,
-    const int ne33,
-    const int nb31,
-    const int nb32,
-    const int nb33,
-    const int nb01,
-    const int nb02,
-    const int nb03,
-    const int nb11,
-    const int nb12,
-    const int nb13,
-    const int nb21,
-    const int nb22,
-    const int nb23,
-    const int ne0,
-    const int ne1,
-    const int ne2,
-    const int ne3) {
+    const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+    const int32_t nb01, const int32_t nb02, const int32_t nb03,
+    const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+    const int32_t nb11, const int32_t nb12, const int64_t nb13,
+    const int32_t nb21, const int32_t nb22, const int64_t nb23,
+    const int32_t ne31, const int32_t ne32, const int32_t ne33,
+    const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE

     // Skip unused kernel variants for faster compilation:

@@ -66,8 +46,7 @@ static __global__ void flash_attn_tile_ext_f32(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
     return;
 }

@@ -135,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(

 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
-                const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+                const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
                 KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
                 KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
             }

@@ -231,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32(
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;

-                KV_tmp2[k*(D/2) + i].x =  __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
-                KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+                const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
+                KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
+                KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
             }
         }

@@ -312,7 +292,6 @@ static __global__ void flash_attn_tile_ext_f32(
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
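Besides the `int64_t` cast, the new V-loading lines in the f32 tile kernel read each `half2` from `V_h2` once into a temporary and then split it into the two float components of `KV_tmp2`. A self-contained sketch of that conversion pattern (kernel and buffer names are invented here, not taken from the commit):

#include <cuda_fp16.h>

// Load one half2 per thread and split it into a float2, mirroring the
// tmp / __low2float / __high2float pattern in the hunk above.
__global__ void half2_to_float2(const half2 * __restrict__ src,
                                float2      * __restrict__ dst, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }

    const half2 tmp = src[i];     // single global load
    dst[i].x = __low2float(tmp);  // lower  half -> float
    dst[i].y = __high2float(tmp); // higher half -> float
}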
ggml/src/ggml-cuda/fattn-vec-f16.cuh
CHANGED

@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f16(
     const float m1,
     const uint32_t n_head_log2,
     const float logit_softcap,
-    const int ne00,
-    const int ne01,
-    const int ne02,
-    const int ne03,
-    const int ne10,
-    const int ne11,
-    const int ne12,
-    const int ne13,
-    const int ne31,
-    const int ne32,
-    const int ne33,
-    const int nb31,
-    const int nb32,
-    const int nb33,
-    const int nb01,
-    const int nb02,
-    const int nb03,
-    const int nb11,
-    const int nb12,
-    const int nb13,
-    const int nb21,
-    const int nb22,
-    const int nb23,
-    const int ne0,
-    const int ne1,
-    const int ne2,
-    const int ne3) {
+    const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+    const int32_t nb01, const int32_t nb02, const int32_t nb03,
+    const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+    const int32_t nb11, const int32_t nb12, const int64_t nb13,
+    const int32_t nb21, const int32_t nb22, const int64_t nb23,
+    const int32_t ne31, const int32_t ne32, const int32_t ne33,
+    const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)

     // Skip unused kernel variants for faster compilation:

@@ -191,13 +171,16 @@ static __global__ void flash_attn_vec_ext_f16(

     half2 VKQ[ncols] = {{0.0f, 0.0f}};

+    K += blockIdx.y*D * nb11;
+    V += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:

         if (mask) {
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
             }

             __syncthreads();

@@ -244,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(

 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+            half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum((float)sum);

             if (use_logit_softcap) {

@@ -300,14 +283,18 @@ static __global__ void flash_attn_vec_ext_f16(
             }

             half2 V_k;
-            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
-            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
+            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
             }
         }

+        K += gridDim.y*D * nb11;
+        V += gridDim.y*D * nb21;
+        maskh += gridDim.y*D;
+
         __syncthreads();
     }

@@ -351,8 +338,7 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
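The vec kernels also change how they walk the KV cache: instead of folding the growing token offset `k_VKQ_0` into every index (`maskh[j*ne11 + k_VKQ_0 + tid]`, `V + (k_VKQ_0 + k0)*nb21`), the new code advances the K, V and mask base pointers once per grid-strided iteration and keeps the per-tile indices small. A stripped-down sketch of that traversal pattern, with invented names and none of the actual attention math:

#include <cstdint>

// Grid-strided walk over a KV cache that advances the base pointer instead of
// multiplying an ever-growing token index by the stride on every access.
__global__ void walk_kv(const char * __restrict__ K, const int64_t nb11,
                        const int ne11, const int D) {
    // Each y-block starts D tokens further into the cache.
    K += (int64_t) blockIdx.y*D * nb11;

    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
        // Within the current tile only small, tile-local offsets are added.
        const char * row = K + (int64_t) threadIdx.x*nb11;
        (void) row; // loads and dot products would go here

        // Move the base pointer to the tile handled in the next iteration.
        K += (int64_t) gridDim.y*D * nb11;
    }
}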
ggml/src/ggml-cuda/fattn-vec-f32.cuh
CHANGED

@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f32(
     const float m1,
     const uint32_t n_head_log2,
     const float logit_softcap,
-    const int ne00,
-    const int ne01,
-    const int ne02,
-    const int ne03,
-    const int ne10,
-    const int ne11,
-    const int ne12,
-    const int ne13,
-    const int ne31,
-    const int ne32,
-    const int ne33,
-    const int nb31,
-    const int nb32,
-    const int nb33,
-    const int nb01,
-    const int nb02,
-    const int nb03,
-    const int nb11,
-    const int nb12,
-    const int nb13,
-    const int nb21,
-    const int nb22,
-    const int nb23,
-    const int ne0,
-    const int ne1,
-    const int ne2,
-    const int ne3) {
+    const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+    const int32_t nb01, const int32_t nb02, const int32_t nb03,
+    const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+    const int32_t nb11, const int32_t nb12, const int64_t nb13,
+    const int32_t nb21, const int32_t nb22, const int64_t nb23,
+    const int32_t ne31, const int32_t ne32, const int32_t ne33,
+    const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE

     // Skip unused kernel variants for faster compilation:

@@ -59,8 +39,7 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
     return;
 }

@@ -198,13 +177,16 @@ static __global__ void flash_attn_vec_ext_f32(

     float VKQ[ncols] = {0.0f};

+    K += blockIdx.y*D * nb11;
+    V += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:

         if (mask) {
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
             }

             __syncthreads();

@@ -246,7 +228,7 @@ static __global__ void flash_attn_vec_ext_f32(

 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+            float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum(sum);

             if (use_logit_softcap) {

@@ -297,13 +279,17 @@ static __global__ void flash_attn_vec_ext_f32(
                 break;
             }

-            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+            const float V_ki = dequantize_1_v(V + k*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_ki*KQ[j*D + k];
             }
         }

+        K += gridDim.y*D * nb11;
+        V += gridDim.y*D * nb21;
+        maskh += gridDim.y*D;
+
         __syncthreads();
     }

@@ -348,7 +334,6 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
ggml/src/ggml-cuda/fattn-wmma-f16.cu
CHANGED

@@ -37,33 +37,13 @@ static __global__ void flash_attn_ext_f16(
     const float m1,
     const uint32_t n_head_log2,
     const float logit_softcap,
-    const int ne00,
-    const int ne01,
-    const int ne02,
-    const int ne03,
-    const int ne10,
-    const int ne11,
-    const int ne12,
-    const int ne13,
-    const int ne31,
-    const int ne32,
-    const int ne33,
-    const int nb31,
-    const int nb32,
-    const int nb33,
-    const int nb01,
-    const int nb02,
-    const int nb03,
-    const int nb11,
-    const int nb12,
-    const int nb13,
-    const int nb21,
-    const int nb22,
-    const int nb23,
-    const int ne0,
-    const int ne1,
-    const int ne2,
-    const int ne3) {
+    const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+    const int32_t nb01, const int32_t nb02, const int32_t nb03,
+    const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+    const int32_t nb11, const int32_t nb12, const int64_t nb13,
+    const int32_t nb21, const int32_t nb22, const int64_t nb23,
+    const int32_t ne31, const int32_t ne32, const int32_t ne33,
+    const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {

@@ -197,7 +177,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);

@@ -344,7 +324,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                 frag_a_V v_a;
-                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);

@@ -451,7 +431,6 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
ggml/src/ggml-cuda/fattn.cu
CHANGED

@@ -280,22 +280,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

-    if (GGML_CUDA_CC_IS_AMD(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN)
-        if (fp16_mma_available(cc)) {
-            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-            return;
-        }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
-
-        // On AMD the tile kernels perform poorly, use the vec kernel instead:
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
+    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
+        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)

     if (!fast_fp16_available(cc)) {
         if (Q->ne[1] <= 8 || Q->ne[0] == 256) {