Spaces:
Running
Running
Commit
·
ace16dc
1
Parent(s):
f8fd66d
CUDA: faster Deepseek FA, add Turing support (llama/13435)
Browse files
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED
|
@@ -678,10 +678,14 @@ void launch_fattn(
|
|
| 678 |
) {
|
| 679 |
constexpr int ncols = ncols1 * ncols2;
|
| 680 |
|
|
|
|
|
|
|
| 681 |
const ggml_tensor * Q = dst->src[0];
|
| 682 |
const ggml_tensor * K = dst->src[1];
|
| 683 |
const ggml_tensor * V = dst->src[2];
|
| 684 |
|
|
|
|
|
|
|
| 685 |
const ggml_tensor * mask = dst->src[3];
|
| 686 |
|
| 687 |
ggml_tensor * KQV = dst;
|
|
@@ -689,6 +693,10 @@ void launch_fattn(
|
|
| 689 |
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
| 690 |
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
| 691 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 692 |
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
| 693 |
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
| 694 |
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
|
@@ -713,10 +721,10 @@ void launch_fattn(
|
|
| 713 |
size_t nb12 = K->nb[2];
|
| 714 |
size_t nb13 = K->nb[3];
|
| 715 |
|
| 716 |
-
const char * V_data = (const char *) V->data;
|
| 717 |
-
size_t nb21 = V->nb[1];
|
| 718 |
-
size_t nb22 = V->nb[2];
|
| 719 |
-
size_t nb23 = V->nb[3];
|
| 720 |
|
| 721 |
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
| 722 |
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
|
@@ -733,7 +741,7 @@ void launch_fattn(
|
|
| 733 |
nb13 = nb13*bs*sizeof(half)/ts;
|
| 734 |
}
|
| 735 |
|
| 736 |
-
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
| 737 |
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
| 738 |
V_f16.alloc(ggml_nelements(V));
|
| 739 |
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
|
|
|
| 678 |
) {
|
| 679 |
constexpr int ncols = ncols1 * ncols2;
|
| 680 |
|
| 681 |
+
const bool is_mla = DV == 512; // TODO better parameterization
|
| 682 |
+
|
| 683 |
const ggml_tensor * Q = dst->src[0];
|
| 684 |
const ggml_tensor * K = dst->src[1];
|
| 685 |
const ggml_tensor * V = dst->src[2];
|
| 686 |
|
| 687 |
+
GGML_ASSERT(V || is_mla);
|
| 688 |
+
|
| 689 |
const ggml_tensor * mask = dst->src[3];
|
| 690 |
|
| 691 |
ggml_tensor * KQV = dst;
|
|
|
|
| 693 |
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
| 694 |
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
| 695 |
|
| 696 |
+
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
|
| 697 |
+
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
|
| 698 |
+
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
| 699 |
+
|
| 700 |
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
| 701 |
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
| 702 |
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
|
|
|
| 721 |
size_t nb12 = K->nb[2];
|
| 722 |
size_t nb13 = K->nb[3];
|
| 723 |
|
| 724 |
+
const char * V_data = V ? (const char *) V->data : nullptr;
|
| 725 |
+
size_t nb21 = V ? V->nb[1] : nb11;
|
| 726 |
+
size_t nb22 = V ? V->nb[2] : nb12;
|
| 727 |
+
size_t nb23 = V ? V->nb[3] : nb13;
|
| 728 |
|
| 729 |
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
| 730 |
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
|
|
|
| 741 |
nb13 = nb13*bs*sizeof(half)/ts;
|
| 742 |
}
|
| 743 |
|
| 744 |
+
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
| 745 |
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
| 746 |
V_f16.alloc(ggml_nelements(V));
|
| 747 |
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
ggml/src/ggml-cuda/fattn-mma-f16.cuh
CHANGED
|
@@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> {
|
|
| 33 |
static constexpr int nwarps_max = 4;
|
| 34 |
static constexpr bool Q_in_reg = true;
|
| 35 |
static constexpr int nstages_target = 2;
|
| 36 |
-
|
| 37 |
-
static
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
};
|
| 40 |
|
| 41 |
template <>
|
|
@@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> {
|
|
| 44 |
static constexpr int nwarps_max = 4;
|
| 45 |
static constexpr bool Q_in_reg = true;
|
| 46 |
static constexpr int nstages_target = 2;
|
| 47 |
-
|
| 48 |
-
static
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
};
|
| 51 |
|
| 52 |
template <>
|
|
@@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> {
|
|
| 55 |
static constexpr int nwarps_max = 4;
|
| 56 |
static constexpr bool Q_in_reg = true;
|
| 57 |
static constexpr int nstages_target = 2;
|
| 58 |
-
|
| 59 |
-
static
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
};
|
| 62 |
|
| 63 |
template <>
|
|
@@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
|
|
| 66 |
static constexpr int nwarps_max = 4;
|
| 67 |
static constexpr bool Q_in_reg = true;
|
| 68 |
static constexpr int nstages_target = 2;
|
| 69 |
-
|
| 70 |
-
static
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
};
|
| 73 |
|
| 74 |
template <>
|
|
@@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
|
|
| 77 |
static constexpr int nwarps_max = 4;
|
| 78 |
static constexpr bool Q_in_reg = true;
|
| 79 |
static constexpr int nstages_target = 2;
|
| 80 |
-
|
| 81 |
-
static
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
};
|
| 84 |
|
| 85 |
template <>
|
|
@@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
|
|
| 88 |
static constexpr int nwarps_max = 4;
|
| 89 |
static constexpr bool Q_in_reg = true;
|
| 90 |
static constexpr int nstages_target = 2;
|
| 91 |
-
|
| 92 |
-
static
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
};
|
| 95 |
|
| 96 |
template <>
|
|
@@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
|
|
| 99 |
static constexpr int nwarps_max = 8;
|
| 100 |
static constexpr bool Q_in_reg = false;
|
| 101 |
static constexpr int nstages_target = 1;
|
| 102 |
-
|
| 103 |
-
static
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
};
|
| 106 |
|
| 107 |
// ------------------------------------------------------------------------------------------------------------------
|
|
@@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
| 120 |
|
| 121 |
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
| 122 |
|
| 123 |
-
auto load = [&] __device__ (
|
| 124 |
const int stride_k = WARP_SIZE >> n;
|
| 125 |
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
| 126 |
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
|
@@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
| 223 |
}
|
| 224 |
}
|
| 225 |
|
| 226 |
-
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
|
| 227 |
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
| 228 |
const float2 * const __restrict__ Q_f2,
|
| 229 |
const half2 * const __restrict__ K_h2,
|
|
@@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 261 |
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 262 |
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
| 263 |
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
-
constexpr int stride_tile_Q = DKQ/2
|
| 266 |
-
constexpr int stride_tile_K =
|
| 267 |
-
|
|
|
|
|
|
|
| 268 |
|
| 269 |
const int k_VKQ_0 = kb0 * c::nbatch_fa;
|
| 270 |
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
|
|
@@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 275 |
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
| 276 |
|
| 277 |
if constexpr (nstages > 1) {
|
| 278 |
-
static_assert(
|
|
|
|
| 279 |
constexpr bool use_cp_async = true;
|
| 280 |
cp_async_wait_all();
|
| 281 |
__syncthreads();
|
| 282 |
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
| 283 |
-
(V_h2 + k_VKQ_0*stride_V, tile_V,
|
| 284 |
} else {
|
| 285 |
constexpr bool use_cp_async = nstages == 1;
|
| 286 |
if (ncols2 > 1 || mask_h2) {
|
|
@@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 289 |
}
|
| 290 |
|
| 291 |
#pragma unroll
|
| 292 |
-
for (int k0_start = 0; k0_start < DKQ/2; k0_start +=
|
| 293 |
-
const int k0_stop = k0_start +
|
| 294 |
const int k0_diff = k0_stop - k0_start;
|
| 295 |
|
| 296 |
if (nstages <= 1) {
|
|
@@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 537 |
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
| 538 |
}
|
| 539 |
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
| 540 |
-
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K,
|
| 541 |
}
|
| 542 |
}
|
| 543 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 544 |
#pragma unroll
|
| 545 |
-
for (int
|
| 546 |
-
const int
|
| 547 |
-
const int i0_diff
|
| 548 |
|
| 549 |
-
if (nstages <= 1) {
|
| 550 |
constexpr bool use_cp_async = nstages == 1;
|
| 551 |
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
| 552 |
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
|
@@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 555 |
}
|
| 556 |
__syncthreads();
|
| 557 |
}
|
|
|
|
| 558 |
|
| 559 |
// Calculate VKQ tile:
|
| 560 |
#pragma unroll
|
|
@@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 565 |
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
| 566 |
|
| 567 |
tile_A A;
|
| 568 |
-
load_ldmatrix_trans(A,
|
| 569 |
if (ntiles == 1) {
|
| 570 |
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
| 571 |
} else {
|
|
@@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 596 |
#endif // NEW_MMA_AVAILABLE
|
| 597 |
}
|
| 598 |
|
| 599 |
-
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
|
| 600 |
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
| 601 |
const float2 * const __restrict__ Q_f2,
|
| 602 |
const half2 * const __restrict__ K_h2,
|
|
@@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 632 |
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 633 |
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
| 634 |
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
|
|
|
|
|
|
| 635 |
|
| 636 |
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
| 637 |
|
| 638 |
-
constexpr int stride_tile_Q = DKQ/2
|
| 639 |
-
constexpr int stride_tile_K =
|
| 640 |
-
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
| 641 |
|
|
|
|
|
|
|
| 642 |
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
| 643 |
|
| 644 |
extern __shared__ half2 tile_Q[];
|
|
@@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 726 |
|
| 727 |
// Preload mask and K data for first iteration when using cp_async with multiple stages:
|
| 728 |
if constexpr (nstages > 1) {
|
| 729 |
-
static_assert(
|
| 730 |
constexpr bool use_cp_async = true;
|
| 731 |
if (ncols2 > 1 || mask_h2) {
|
| 732 |
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
|
| 733 |
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
| 734 |
}
|
| 735 |
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
| 736 |
-
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K,
|
| 737 |
}
|
| 738 |
|
| 739 |
// Iterate over ne11 == previous tokens:
|
| 740 |
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
| 741 |
constexpr bool last_iter = false;
|
| 742 |
-
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
| 743 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
| 744 |
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
| 745 |
}
|
| 746 |
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
| 747 |
constexpr bool last_iter = true;
|
| 748 |
-
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
| 749 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
| 750 |
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
| 751 |
}
|
|
@@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 774 |
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
| 775 |
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
| 776 |
|
| 777 |
-
constexpr int nbatch_combine = c::
|
| 778 |
constexpr int tile_stride = nbatch_combine + 4;
|
| 779 |
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
|
| 780 |
|
|
@@ -1012,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 1012 |
#endif // NEW_MMA_AVAILABLE
|
| 1013 |
}
|
| 1014 |
|
| 1015 |
-
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
|
| 1016 |
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
| 1017 |
static __global__ void flash_attn_ext_f16(
|
| 1018 |
const char * __restrict__ Q,
|
|
@@ -1057,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
|
|
| 1057 |
NO_DEVICE_CODE;
|
| 1058 |
return;
|
| 1059 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1060 |
|
| 1061 |
typedef fattn_mma_f16_config<DKQ, DV> c;
|
| 1062 |
|
|
@@ -1067,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
|
|
| 1067 |
const int stride_Q1 = nb01 / sizeof(float2);
|
| 1068 |
const int stride_Q2 = nb02 / sizeof(float2);
|
| 1069 |
const int stride_K = nb11 / sizeof(half2);
|
| 1070 |
-
const int stride_V = nb21 / sizeof(half2);
|
| 1071 |
const int stride_mask = nb31 / sizeof(half2);
|
| 1072 |
|
|
|
|
|
|
|
| 1073 |
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
| 1074 |
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
| 1075 |
|
|
@@ -1092,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
|
|
| 1092 |
|
| 1093 |
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
| 1094 |
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
| 1095 |
-
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
| 1096 |
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
| 1097 |
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
| 1098 |
|
|
|
|
|
|
|
| 1099 |
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
| 1100 |
|
| 1101 |
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
@@ -1104,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
|
|
| 1104 |
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
| 1105 |
if (kb0_start == 0) {
|
| 1106 |
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
| 1107 |
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
| 1108 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 1109 |
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 1110 |
} else {
|
| 1111 |
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
| 1112 |
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
| 1113 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 1114 |
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 1115 |
}
|
|
@@ -1130,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
|
|
| 1130 |
|
| 1131 |
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
| 1132 |
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
| 1133 |
-
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
| 1134 |
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
| 1135 |
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
| 1136 |
|
|
|
|
|
|
|
| 1137 |
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
| 1138 |
|
| 1139 |
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
@@ -1141,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
|
|
| 1141 |
|
| 1142 |
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
| 1143 |
constexpr bool needs_fixup = false;
|
| 1144 |
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
| 1145 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 1146 |
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 1147 |
#else
|
|
@@ -1167,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
| 1167 |
|
| 1168 |
typedef fattn_mma_f16_config<DKQ, DV> c;
|
| 1169 |
|
| 1170 |
-
constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
|
| 1171 |
-
constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
|
| 1172 |
-
constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
|
| 1173 |
-
|
| 1174 |
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
|
| 1175 |
|
| 1176 |
constexpr int ncols = ncols1 * ncols2;
|
|
@@ -1180,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
| 1180 |
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
|
| 1181 |
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
|
| 1182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1183 |
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
|
| 1184 |
static_assert(DV % tile_A::J == 0, "bad DV");
|
| 1185 |
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
| 1186 |
|
| 1187 |
-
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(
|
| 1188 |
-
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (
|
| 1189 |
-
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4)
|
| 1190 |
-
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4)
|
| 1191 |
-
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4)
|
| 1192 |
|
| 1193 |
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
|
| 1194 |
|
|
@@ -1202,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
| 1202 |
fattn_kernel_t fattn_kernel;
|
| 1203 |
if (logit_softcap == 0.0f) {
|
| 1204 |
constexpr bool use_logit_softcap = false;
|
| 1205 |
-
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
| 1206 |
|
| 1207 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 1208 |
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
@@ -1213,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
| 1213 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 1214 |
} else {
|
| 1215 |
constexpr bool use_logit_softcap = true;
|
| 1216 |
-
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
| 1217 |
|
| 1218 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 1219 |
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
|
|
| 33 |
static constexpr int nwarps_max = 4;
|
| 34 |
static constexpr bool Q_in_reg = true;
|
| 35 |
static constexpr int nstages_target = 2;
|
| 36 |
+
|
| 37 |
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
| 38 |
+
return 32;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
| 42 |
+
return 32;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
| 46 |
+
return 32;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
| 50 |
+
return 32;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
| 54 |
+
return 32;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
| 58 |
+
return 32;
|
| 59 |
+
}
|
| 60 |
};
|
| 61 |
|
| 62 |
template <>
|
|
|
|
| 65 |
static constexpr int nwarps_max = 4;
|
| 66 |
static constexpr bool Q_in_reg = true;
|
| 67 |
static constexpr int nstages_target = 2;
|
| 68 |
+
|
| 69 |
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
| 70 |
+
return 40;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
| 74 |
+
return 40;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
| 78 |
+
return 40;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
| 82 |
+
return 40;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
| 86 |
+
return 40;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
| 90 |
+
return 40;
|
| 91 |
+
}
|
| 92 |
};
|
| 93 |
|
| 94 |
template <>
|
|
|
|
| 97 |
static constexpr int nwarps_max = 4;
|
| 98 |
static constexpr bool Q_in_reg = true;
|
| 99 |
static constexpr int nstages_target = 2;
|
| 100 |
+
|
| 101 |
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
| 102 |
+
return 48;
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
| 106 |
+
return 48;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
| 110 |
+
return 48;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
| 114 |
+
return 48;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
| 118 |
+
return 48;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
| 122 |
+
return 48;
|
| 123 |
+
}
|
| 124 |
};
|
| 125 |
|
| 126 |
template <>
|
|
|
|
| 129 |
static constexpr int nwarps_max = 4;
|
| 130 |
static constexpr bool Q_in_reg = true;
|
| 131 |
static constexpr int nstages_target = 2;
|
| 132 |
+
|
| 133 |
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
| 134 |
+
return 56;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
| 138 |
+
return 56;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
| 142 |
+
return 56;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
| 146 |
+
return 56;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
| 150 |
+
return 56;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
| 154 |
+
return 56;
|
| 155 |
+
}
|
| 156 |
};
|
| 157 |
|
| 158 |
template <>
|
|
|
|
| 161 |
static constexpr int nwarps_max = 4;
|
| 162 |
static constexpr bool Q_in_reg = true;
|
| 163 |
static constexpr int nstages_target = 2;
|
| 164 |
+
|
| 165 |
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
| 166 |
+
return 64;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
| 170 |
+
return 64;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
| 174 |
+
return 64;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
| 178 |
+
return 64;
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
| 182 |
+
return 64;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
| 186 |
+
return 64;
|
| 187 |
+
}
|
| 188 |
};
|
| 189 |
|
| 190 |
template <>
|
|
|
|
| 193 |
static constexpr int nwarps_max = 4;
|
| 194 |
static constexpr bool Q_in_reg = true;
|
| 195 |
static constexpr int nstages_target = 2;
|
| 196 |
+
|
| 197 |
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
| 198 |
+
return 128;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
| 202 |
+
return 128;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
| 206 |
+
return 128;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
| 210 |
+
return 128;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
static int get_nbatch_combine_host(const int cc, const int ncols) {
|
| 214 |
+
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
| 215 |
+
return ncols <= 16 ? 128 : 64;
|
| 216 |
+
}
|
| 217 |
+
return 64;
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
|
| 221 |
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
| 222 |
+
return ncols <= 16 ? 128 : 64;
|
| 223 |
+
#else
|
| 224 |
+
GGML_UNUSED(ncols);
|
| 225 |
+
return 128;
|
| 226 |
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
| 227 |
+
}
|
| 228 |
};
|
| 229 |
|
| 230 |
template <>
|
|
|
|
| 233 |
static constexpr int nwarps_max = 8;
|
| 234 |
static constexpr bool Q_in_reg = false;
|
| 235 |
static constexpr int nstages_target = 1;
|
| 236 |
+
|
| 237 |
+
static int get_nbatch_K2_host(const int cc, const int ncols) {
|
| 238 |
+
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
| 239 |
+
return ncols <= 16 ? 96 : 160;
|
| 240 |
+
}
|
| 241 |
+
return ncols <= 16 ? 288 : 160;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
static constexpr __device__ int get_nbatch_K2_device(int ncols) {
|
| 245 |
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
| 246 |
+
return ncols <= 16 ? 96 : 160;
|
| 247 |
+
#else
|
| 248 |
+
return ncols <= 16 ? 288 : 160;
|
| 249 |
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
static int get_nbatch_V2_host(const int cc, const int ncols) {
|
| 253 |
+
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
| 254 |
+
return ncols <= 16 ? 64 : 128;
|
| 255 |
+
}
|
| 256 |
+
return ncols <= 16 ? 256 : 128;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
static constexpr __device__ int get_nbatch_V2_device(int ncols) {
|
| 260 |
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
| 261 |
+
return ncols <= 16 ? 64 : 128;
|
| 262 |
+
#else
|
| 263 |
+
return ncols <= 16 ? 256 : 128;
|
| 264 |
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
| 268 |
+
return 128;
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
| 272 |
+
return 128;
|
| 273 |
+
}
|
| 274 |
};
|
| 275 |
|
| 276 |
// ------------------------------------------------------------------------------------------------------------------
|
|
|
|
| 289 |
|
| 290 |
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
| 291 |
|
| 292 |
+
auto load = [&] __device__ (auto n) {
|
| 293 |
const int stride_k = WARP_SIZE >> n;
|
| 294 |
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
| 295 |
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
|
|
|
| 392 |
}
|
| 393 |
}
|
| 394 |
|
| 395 |
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
|
| 396 |
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
| 397 |
const float2 * const __restrict__ Q_f2,
|
| 398 |
const half2 * const __restrict__ K_h2,
|
|
|
|
| 430 |
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 431 |
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
| 432 |
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
| 433 |
+
constexpr int ncols = ncols1 * ncols2;
|
| 434 |
+
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
| 435 |
+
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
| 436 |
|
| 437 |
+
constexpr int stride_tile_Q = DKQ/2 + 4;
|
| 438 |
+
constexpr int stride_tile_K = nbatch_K2 + 4;
|
| 439 |
+
|
| 440 |
+
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
| 441 |
+
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
| 442 |
|
| 443 |
const int k_VKQ_0 = kb0 * c::nbatch_fa;
|
| 444 |
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
|
|
|
|
| 449 |
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
| 450 |
|
| 451 |
if constexpr (nstages > 1) {
|
| 452 |
+
static_assert(!mla, "multi-stage loading not implemented for MLA");
|
| 453 |
+
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
| 454 |
constexpr bool use_cp_async = true;
|
| 455 |
cp_async_wait_all();
|
| 456 |
__syncthreads();
|
| 457 |
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
| 458 |
+
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
| 459 |
} else {
|
| 460 |
constexpr bool use_cp_async = nstages == 1;
|
| 461 |
if (ncols2 > 1 || mask_h2) {
|
|
|
|
| 464 |
}
|
| 465 |
|
| 466 |
#pragma unroll
|
| 467 |
+
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
|
| 468 |
+
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
| 469 |
const int k0_diff = k0_stop - k0_start;
|
| 470 |
|
| 471 |
if (nstages <= 1) {
|
|
|
|
| 712 |
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
| 713 |
}
|
| 714 |
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
| 715 |
+
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
| 716 |
}
|
| 717 |
}
|
| 718 |
|
| 719 |
+
|
| 720 |
+
// For MLA K and V have the same data.
|
| 721 |
+
// Therefore, iterate over V in reverse and re-use the data if possible.
|
| 722 |
+
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
| 723 |
+
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
| 724 |
#pragma unroll
|
| 725 |
+
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
|
| 726 |
+
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
|
| 727 |
+
const int i0_diff = i0_stop - i0_start;
|
| 728 |
|
| 729 |
+
if (nstages <= 1 && i0_start < reusable_cutoff) {
|
| 730 |
constexpr bool use_cp_async = nstages == 1;
|
| 731 |
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
| 732 |
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
|
|
|
| 735 |
}
|
| 736 |
__syncthreads();
|
| 737 |
}
|
| 738 |
+
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
|
| 739 |
|
| 740 |
// Calculate VKQ tile:
|
| 741 |
#pragma unroll
|
|
|
|
| 746 |
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
| 747 |
|
| 748 |
tile_A A;
|
| 749 |
+
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
| 750 |
if (ntiles == 1) {
|
| 751 |
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
| 752 |
} else {
|
|
|
|
| 777 |
#endif // NEW_MMA_AVAILABLE
|
| 778 |
}
|
| 779 |
|
| 780 |
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
| 781 |
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
| 782 |
const float2 * const __restrict__ Q_f2,
|
| 783 |
const half2 * const __restrict__ K_h2,
|
|
|
|
| 813 |
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 814 |
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
| 815 |
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
| 816 |
+
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
| 817 |
+
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
| 818 |
|
| 819 |
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
| 820 |
|
| 821 |
+
constexpr int stride_tile_Q = DKQ/2 + 4;
|
| 822 |
+
constexpr int stride_tile_K = nbatch_K2 + 4;
|
|
|
|
| 823 |
|
| 824 |
+
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
| 825 |
+
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
| 826 |
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
| 827 |
|
| 828 |
extern __shared__ half2 tile_Q[];
|
|
|
|
| 910 |
|
| 911 |
// Preload mask and K data for first iteration when using cp_async with multiple stages:
|
| 912 |
if constexpr (nstages > 1) {
|
| 913 |
+
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
| 914 |
constexpr bool use_cp_async = true;
|
| 915 |
if (ncols2 > 1 || mask_h2) {
|
| 916 |
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
|
| 917 |
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
| 918 |
}
|
| 919 |
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
| 920 |
+
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
| 921 |
}
|
| 922 |
|
| 923 |
// Iterate over ne11 == previous tokens:
|
| 924 |
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
| 925 |
constexpr bool last_iter = false;
|
| 926 |
+
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
| 927 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
| 928 |
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
| 929 |
}
|
| 930 |
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
| 931 |
constexpr bool last_iter = true;
|
| 932 |
+
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
| 933 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
| 934 |
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
| 935 |
}
|
|
|
|
| 958 |
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
| 959 |
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
| 960 |
|
| 961 |
+
constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
|
| 962 |
constexpr int tile_stride = nbatch_combine + 4;
|
| 963 |
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
|
| 964 |
|
|
|
|
| 1196 |
#endif // NEW_MMA_AVAILABLE
|
| 1197 |
}
|
| 1198 |
|
| 1199 |
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
|
| 1200 |
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
| 1201 |
static __global__ void flash_attn_ext_f16(
|
| 1202 |
const char * __restrict__ Q,
|
|
|
|
| 1241 |
NO_DEVICE_CODE;
|
| 1242 |
return;
|
| 1243 |
}
|
| 1244 |
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
| 1245 |
+
if (ncols1*ncols2 > 32) {
|
| 1246 |
+
NO_DEVICE_CODE;
|
| 1247 |
+
return;
|
| 1248 |
+
}
|
| 1249 |
+
#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
| 1250 |
+
|
| 1251 |
+
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
| 1252 |
|
| 1253 |
typedef fattn_mma_f16_config<DKQ, DV> c;
|
| 1254 |
|
|
|
|
| 1259 |
const int stride_Q1 = nb01 / sizeof(float2);
|
| 1260 |
const int stride_Q2 = nb02 / sizeof(float2);
|
| 1261 |
const int stride_K = nb11 / sizeof(half2);
|
|
|
|
| 1262 |
const int stride_mask = nb31 / sizeof(half2);
|
| 1263 |
|
| 1264 |
+
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
| 1265 |
+
|
| 1266 |
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
| 1267 |
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
| 1268 |
|
|
|
|
| 1285 |
|
| 1286 |
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
| 1287 |
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
|
|
| 1288 |
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
| 1289 |
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
| 1290 |
|
| 1291 |
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
| 1292 |
+
|
| 1293 |
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
| 1294 |
|
| 1295 |
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
|
|
| 1298 |
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
| 1299 |
if (kb0_start == 0) {
|
| 1300 |
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
| 1301 |
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
| 1302 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 1303 |
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 1304 |
} else {
|
| 1305 |
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
| 1306 |
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
| 1307 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 1308 |
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 1309 |
}
|
|
|
|
| 1324 |
|
| 1325 |
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
| 1326 |
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
|
|
| 1327 |
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
| 1328 |
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
| 1329 |
|
| 1330 |
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
| 1331 |
+
|
| 1332 |
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
| 1333 |
|
| 1334 |
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
|
|
| 1336 |
|
| 1337 |
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
| 1338 |
constexpr bool needs_fixup = false;
|
| 1339 |
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
| 1340 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 1341 |
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 1342 |
#else
|
|
|
|
| 1362 |
|
| 1363 |
typedef fattn_mma_f16_config<DKQ, DV> c;
|
| 1364 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1365 |
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
|
| 1366 |
|
| 1367 |
constexpr int ncols = ncols1 * ncols2;
|
|
|
|
| 1371 |
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
|
| 1372 |
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
|
| 1373 |
|
| 1374 |
+
constexpr bool mla = DKQ == 576;
|
| 1375 |
+
|
| 1376 |
+
const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
|
| 1377 |
+
const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
|
| 1378 |
+
const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
|
| 1379 |
+
|
| 1380 |
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
|
| 1381 |
static_assert(DV % tile_A::J == 0, "bad DV");
|
| 1382 |
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
| 1383 |
|
| 1384 |
+
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
| 1385 |
+
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
| 1386 |
+
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
| 1387 |
+
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
| 1388 |
+
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
| 1389 |
|
| 1390 |
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
|
| 1391 |
|
|
|
|
| 1399 |
fattn_kernel_t fattn_kernel;
|
| 1400 |
if (logit_softcap == 0.0f) {
|
| 1401 |
constexpr bool use_logit_softcap = false;
|
| 1402 |
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
| 1403 |
|
| 1404 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 1405 |
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
|
|
| 1410 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 1411 |
} else {
|
| 1412 |
constexpr bool use_logit_softcap = true;
|
| 1413 |
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
| 1414 |
|
| 1415 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 1416 |
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
ggml/src/ggml-cuda/fattn.cu
CHANGED
|
@@ -10,6 +10,7 @@
|
|
| 10 |
|
| 11 |
template <int DKQ, int DV, int ncols2>
|
| 12 |
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
| 13 |
const ggml_tensor * Q = dst->src[0];
|
| 14 |
|
| 15 |
if constexpr (ncols2 <= 8) {
|
|
@@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
|
| 24 |
return;
|
| 25 |
}
|
| 26 |
|
| 27 |
-
if (Q->ne[1] <= 32/ncols2) {
|
| 28 |
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
| 29 |
return;
|
| 30 |
}
|
|
|
|
| 10 |
|
| 11 |
template <int DKQ, int DV, int ncols2>
|
| 12 |
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 13 |
+
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
| 14 |
const ggml_tensor * Q = dst->src[0];
|
| 15 |
|
| 16 |
if constexpr (ncols2 <= 8) {
|
|
|
|
| 25 |
return;
|
| 26 |
}
|
| 27 |
|
| 28 |
+
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
|
| 29 |
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
| 30 |
return;
|
| 31 |
}
|
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -3222,7 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 3222 |
#endif // FLASH_ATTN_AVAILABLE
|
| 3223 |
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
| 3224 |
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
| 3225 |
-
if (!new_mma_available(cc)
|
| 3226 |
return false;
|
| 3227 |
}
|
| 3228 |
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
|
|
|
| 3222 |
#endif // FLASH_ATTN_AVAILABLE
|
| 3223 |
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
| 3224 |
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
| 3225 |
+
if (!new_mma_available(cc)) {
|
| 3226 |
return false;
|
| 3227 |
}
|
| 3228 |
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|