Spaces:
Running
Running
Commit
·
507d30c
1
Parent(s):
2eca371
CUDA: FA support for Deepseek (Ampere or newer) (llama/13306)
Browse files* CUDA: FA support for Deepseek (Ampere or newer)
* do loop unrolling via C++ template
- ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- ggml/src/ggml-cuda/common.cuh +19 -0
- ggml/src/ggml-cuda/cp-async.cuh +11 -0
- ggml/src/ggml-cuda/fattn-common.cuh +13 -13
- ggml/src/ggml-cuda/fattn-mma-f16.cuh +564 -344
- ggml/src/ggml-cuda/fattn-tile-f16.cu +2 -2
- ggml/src/ggml-cuda/fattn-tile-f32.cu +2 -2
- ggml/src/ggml-cuda/fattn-vec-f16.cuh +1 -1
- ggml/src/ggml-cuda/fattn-vec-f32.cuh +1 -1
- ggml/src/ggml-cuda/fattn-wmma-f16.cu +1 -1
- ggml/src/ggml-cuda/fattn.cu +71 -45
- ggml/src/ggml-cuda/ggml-cuda.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +6 -6
- ggml/src/ggml-cuda/template-instances/generate_cu_files.py +12 -9
ggml/src/ggml-cuda/CMakeLists.txt
CHANGED
|
@@ -118,7 +118,7 @@ if (CUDAToolkit_FOUND)
|
|
| 118 |
|
| 119 |
set(CUDA_CXX_FLAGS "")
|
| 120 |
|
| 121 |
-
set(CUDA_FLAGS -use_fast_math)
|
| 122 |
|
| 123 |
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
|
| 124 |
# Options are:
|
|
|
|
| 118 |
|
| 119 |
set(CUDA_CXX_FLAGS "")
|
| 120 |
|
| 121 |
+
set(CUDA_FLAGS -use_fast_math -extended-lambda)
|
| 122 |
|
| 123 |
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
|
| 124 |
# Options are:
|
ggml/src/ggml-cuda/common.cuh
CHANGED
|
@@ -296,6 +296,25 @@ static __device__ void no_device_code(
|
|
| 296 |
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
|
| 297 |
#endif // __CUDA_ARCH__
|
| 298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
template<int width = WARP_SIZE>
|
| 300 |
static __device__ __forceinline__ int warp_reduce_sum(int x) {
|
| 301 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
|
|
|
| 296 |
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
|
| 297 |
#endif // __CUDA_ARCH__
|
| 298 |
|
| 299 |
+
// The compiler is always able to unroll loops if they contain continue expressions.
|
| 300 |
+
// In such cases loop unrolling can still be achieved via recursion:
|
| 301 |
+
template <int n>
|
| 302 |
+
struct ggml_cuda_unroll {
|
| 303 |
+
template <typename Func, typename... Args>
|
| 304 |
+
__device__ void operator()(const Func & f, Args... args) const {
|
| 305 |
+
f(n - 1, args...);
|
| 306 |
+
ggml_cuda_unroll<n - 1>{}(f, args...);
|
| 307 |
+
}
|
| 308 |
+
};
|
| 309 |
+
|
| 310 |
+
template <>
|
| 311 |
+
struct ggml_cuda_unroll<1> {
|
| 312 |
+
template <typename Func, typename... Args>
|
| 313 |
+
__device__ void operator()(const Func & f, Args... args) const {
|
| 314 |
+
f(0, args...);
|
| 315 |
+
}
|
| 316 |
+
};
|
| 317 |
+
|
| 318 |
template<int width = WARP_SIZE>
|
| 319 |
static __device__ __forceinline__ int warp_reduce_sum(int x) {
|
| 320 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
ggml/src/ggml-cuda/cp-async.cuh
CHANGED
|
@@ -2,6 +2,17 @@
|
|
| 2 |
|
| 3 |
#include "common.cuh"
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
// Copies data from global to shared memory, cg == cache global.
|
| 6 |
// Both the src and dst pointers must be aligned to 16 bit.
|
| 7 |
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
|
|
|
|
| 2 |
|
| 3 |
#include "common.cuh"
|
| 4 |
|
| 5 |
+
|
| 6 |
+
static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
|
| 7 |
+
#ifdef CP_ASYNC_AVAILABLE
|
| 8 |
+
return __cvta_generic_to_shared(generic_ptr);
|
| 9 |
+
#else
|
| 10 |
+
GGML_UNUSED(generic_ptr);
|
| 11 |
+
NO_DEVICE_CODE;
|
| 12 |
+
return 0;
|
| 13 |
+
#endif // CP_ASYNC_AVAILABLE
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
// Copies data from global to shared memory, cg == cache global.
|
| 17 |
// Both the src and dst pointers must be aligned to 16 bit.
|
| 18 |
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
|
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED
|
@@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
|
| 516 |
nullptr;
|
| 517 |
}
|
| 518 |
|
| 519 |
-
template<int D, int ncols1, int ncols2
|
| 520 |
__launch_bounds__(D, 1)
|
| 521 |
static __global__ void flash_attn_stream_k_fixup(
|
| 522 |
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
|
@@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) {
|
|
| 665 |
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
|
| 666 |
GGML_ABORT("fatal error");
|
| 667 |
} else {
|
| 668 |
-
fprintf(stderr, "Unsupported KV type combination for head_size
|
| 669 |
fprintf(stderr, "Only f16 is supported.\n");
|
| 670 |
GGML_ABORT("fatal error");
|
| 671 |
}
|
| 672 |
}
|
| 673 |
|
| 674 |
-
template <int
|
| 675 |
void launch_fattn(
|
| 676 |
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
| 677 |
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
|
@@ -691,7 +691,7 @@ void launch_fattn(
|
|
| 691 |
|
| 692 |
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
| 693 |
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
| 694 |
-
|
| 695 |
|
| 696 |
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
| 697 |
|
|
@@ -754,10 +754,13 @@ void launch_fattn(
|
|
| 754 |
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
| 755 |
|
| 756 |
const dim3 block_dim(warp_size, nwarps, 1);
|
|
|
|
|
|
|
|
|
|
| 757 |
dim3 blocks_num;
|
| 758 |
if (stream_k) {
|
| 759 |
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
| 760 |
-
const int max_blocks =
|
| 761 |
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
| 762 |
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
|
| 763 |
|
|
@@ -769,14 +772,11 @@ void launch_fattn(
|
|
| 769 |
blocks_num.y = 1;
|
| 770 |
blocks_num.z = 1;
|
| 771 |
|
| 772 |
-
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 +
|
| 773 |
} else {
|
| 774 |
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
| 775 |
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
| 776 |
|
| 777 |
-
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
| 778 |
-
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
| 779 |
-
|
| 780 |
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
| 781 |
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
| 782 |
|
|
@@ -853,19 +853,19 @@ void launch_fattn(
|
|
| 853 |
|
| 854 |
if (stream_k) {
|
| 855 |
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
| 856 |
-
const dim3 block_dim_combine(
|
| 857 |
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
| 858 |
|
| 859 |
-
flash_attn_stream_k_fixup<
|
| 860 |
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
| 861 |
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
| 862 |
}
|
| 863 |
} else if (parallel_blocks > 1) {
|
| 864 |
-
const dim3 block_dim_combine(
|
| 865 |
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
| 866 |
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
| 867 |
|
| 868 |
-
flash_attn_combine_results<
|
| 869 |
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
|
| 870 |
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
|
| 871 |
}
|
|
|
|
| 516 |
nullptr;
|
| 517 |
}
|
| 518 |
|
| 519 |
+
template<int D, int ncols1, int ncols2> // D == head size
|
| 520 |
__launch_bounds__(D, 1)
|
| 521 |
static __global__ void flash_attn_stream_k_fixup(
|
| 522 |
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
|
|
|
| 665 |
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
|
| 666 |
GGML_ABORT("fatal error");
|
| 667 |
} else {
|
| 668 |
+
fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
|
| 669 |
fprintf(stderr, "Only f16 is supported.\n");
|
| 670 |
GGML_ABORT("fatal error");
|
| 671 |
}
|
| 672 |
}
|
| 673 |
|
| 674 |
+
template <int DV, int ncols1, int ncols2>
|
| 675 |
void launch_fattn(
|
| 676 |
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
| 677 |
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
|
|
|
| 691 |
|
| 692 |
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
| 693 |
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
| 694 |
+
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
| 695 |
|
| 696 |
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
| 697 |
|
|
|
|
| 754 |
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
| 755 |
|
| 756 |
const dim3 block_dim(warp_size, nwarps, 1);
|
| 757 |
+
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
| 758 |
+
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
| 759 |
+
|
| 760 |
dim3 blocks_num;
|
| 761 |
if (stream_k) {
|
| 762 |
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
| 763 |
+
const int max_blocks = max_blocks_per_sm*nsm;
|
| 764 |
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
| 765 |
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
|
| 766 |
|
|
|
|
| 772 |
blocks_num.y = 1;
|
| 773 |
blocks_num.z = 1;
|
| 774 |
|
| 775 |
+
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
|
| 776 |
} else {
|
| 777 |
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
| 778 |
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
| 779 |
|
|
|
|
|
|
|
|
|
|
| 780 |
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
| 781 |
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
| 782 |
|
|
|
|
| 853 |
|
| 854 |
if (stream_k) {
|
| 855 |
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
| 856 |
+
const dim3 block_dim_combine(DV, 1, 1);
|
| 857 |
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
| 858 |
|
| 859 |
+
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
| 860 |
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
| 861 |
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
| 862 |
}
|
| 863 |
} else if (parallel_blocks > 1) {
|
| 864 |
+
const dim3 block_dim_combine(DV, 1, 1);
|
| 865 |
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
| 866 |
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
| 867 |
|
| 868 |
+
flash_attn_combine_results<DV>
|
| 869 |
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
|
| 870 |
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
|
| 871 |
}
|
ggml/src/ggml-cuda/fattn-mma-f16.cuh
CHANGED
|
@@ -13,104 +13,217 @@ typedef tile<16, 16, float> tile_C_KQ_16;
|
|
| 13 |
typedef tile<16, 4, half2> tile_C_VKQ;
|
| 14 |
typedef tile<16, 8, half2> tile_C_VKQ_16;
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
| 18 |
-
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
|
| 19 |
-
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
|
| 20 |
|
| 21 |
-
//
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
constexpr int preload = 64;
|
| 29 |
-
constexpr int h2_per_chunk = 16/sizeof(half2);
|
| 30 |
-
constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
|
| 31 |
-
constexpr int stride_i = WARP_SIZE / chunks_per_row;
|
| 32 |
#pragma unroll
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
constexpr int k0_sync_start = 0;
|
| 41 |
-
#endif // CP_ASYNC_AVAILABLE
|
| 42 |
-
static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
|
| 43 |
|
| 44 |
-
// If D is not a power of 2, the rest is loaded synchronously.
|
| 45 |
-
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
| 46 |
-
static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds");
|
| 47 |
#pragma unroll
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
const int k0_stop = D/2 - (D/2) % (1*stride_k);
|
| 51 |
-
const int stride_i = WARP_SIZE / stride_k;
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
#pragma unroll
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
#pragma unroll
|
| 62 |
-
|
| 63 |
-
|
| 64 |
|
| 65 |
-
|
|
|
|
| 66 |
}
|
| 67 |
-
}
|
|
|
|
| 68 |
}
|
| 69 |
}
|
| 70 |
|
| 71 |
-
template<int ncols1, int nwarps, int
|
| 72 |
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
| 73 |
const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
|
| 74 |
-
static_assert(
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
| 79 |
|
| 80 |
-
|
| 81 |
|
| 82 |
#pragma unroll
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
|
| 91 |
-
|
| 92 |
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
}
|
| 95 |
-
|
| 96 |
-
constexpr int cols_per_warp = 2*WARP_SIZE/
|
| 97 |
constexpr int stride_j = nwarps * cols_per_warp;
|
| 98 |
#pragma unroll
|
| 99 |
for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
|
| 100 |
-
const int j = j0 + threadIdx.y*cols_per_warp + (
|
| 101 |
|
| 102 |
if (j0 + stride_j > ncols1 && j >= ncols1) {
|
| 103 |
break;
|
| 104 |
}
|
| 105 |
|
| 106 |
-
const int i =
|
| 107 |
|
| 108 |
-
tile_mask[j*(
|
| 109 |
}
|
| 110 |
-
#endif // CP_ASYNC_AVAILABLE
|
| 111 |
}
|
| 112 |
|
| 113 |
-
template<int
|
| 114 |
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
| 115 |
const float2 * const __restrict__ Q_f2,
|
| 116 |
const half2 * const __restrict__ K_h2,
|
|
@@ -123,9 +236,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 123 |
const float logit_softcap,
|
| 124 |
const int ne01,
|
| 125 |
const int ne02,
|
| 126 |
-
const int
|
|
|
|
| 127 |
const int stride_mask,
|
| 128 |
const int jt,
|
|
|
|
| 129 |
half2 * const __restrict__ tile_K,
|
| 130 |
half2 * const __restrict__ tile_V,
|
| 131 |
half2 * const __restrict__ tile_mask,
|
|
@@ -135,59 +250,107 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 135 |
float * const __restrict__ KQ_rowsum,
|
| 136 |
const int kb0) {
|
| 137 |
#ifdef NEW_MMA_AVAILABLE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 139 |
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
| 140 |
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
| 141 |
-
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
// Use wide variants of tiles if ntiles >= 2.
|
| 147 |
tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
|
| 148 |
tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
|
| 149 |
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
}
|
| 159 |
-
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
|
| 160 |
-
__syncthreads();
|
| 161 |
-
#endif // CP_ASYNC_AVAILABLE
|
| 162 |
|
| 163 |
-
// Calculate tile of KQ:
|
| 164 |
#pragma unroll
|
| 165 |
-
for (int
|
| 166 |
-
const int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
#pragma unroll
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
|
| 171 |
-
if (ntiles == 1) {
|
| 172 |
-
mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
|
| 173 |
-
} else {
|
| 174 |
#pragma unroll
|
| 175 |
-
for (int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
// Wide version of KQ_C is column-major => swap A and B.
|
| 177 |
-
mma(KQ_C_16[i_KQ_00/(np*tile_A::I)
|
| 178 |
}
|
| 179 |
}
|
| 180 |
}
|
| 181 |
-
}
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
| 186 |
|
| 187 |
if (use_logit_softcap) {
|
| 188 |
-
static_assert(
|
| 189 |
#pragma unroll
|
| 190 |
-
for (int i = 0; i <
|
| 191 |
#pragma unroll
|
| 192 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 193 |
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
|
|
@@ -205,7 +368,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 205 |
if (ntiles == 1) {
|
| 206 |
if (ncols2 > 1 || mask_h2) {
|
| 207 |
#pragma unroll
|
| 208 |
-
for (int i00 = 0; i00 <
|
| 209 |
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
|
| 210 |
#pragma unroll
|
| 211 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
|
@@ -213,16 +376,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 213 |
const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
|
| 214 |
|
| 215 |
KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
|
| 216 |
-
__half2float(((const half *) tile_mask)[j*(
|
| 217 |
}
|
| 218 |
}
|
| 219 |
}
|
| 220 |
|
| 221 |
// Calculate softmax for each KQ column using the current max. value.
|
| 222 |
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
| 223 |
-
static_assert(
|
| 224 |
#pragma unroll
|
| 225 |
-
for (int k = 0; k <
|
| 226 |
#pragma unroll
|
| 227 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 228 |
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
|
|
@@ -238,10 +401,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 238 |
}
|
| 239 |
}
|
| 240 |
|
| 241 |
-
static_assert(
|
| 242 |
-
|
| 243 |
#pragma unroll
|
| 244 |
-
for (int k = 0; k <
|
| 245 |
#pragma unroll
|
| 246 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 247 |
KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
|
|
@@ -252,7 +414,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 252 |
} else { // ntiles > 1
|
| 253 |
if (ncols2 > 1 || mask_h2) {
|
| 254 |
#pragma unroll
|
| 255 |
-
for (int i00 = 0; i00 <
|
| 256 |
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
|
| 257 |
#pragma unroll
|
| 258 |
for (int t = 0; t < ntiles/2; ++t) {
|
|
@@ -261,7 +423,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 261 |
const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
|
| 262 |
const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
|
| 263 |
|
| 264 |
-
const float2 tmp = __half22float2(tile_mask[j*(
|
| 265 |
const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
|
| 266 |
KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
|
| 267 |
KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
|
|
@@ -272,9 +434,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 272 |
|
| 273 |
// Calculate softmax for each KQ column using the current max. value.
|
| 274 |
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
| 275 |
-
static_assert(
|
| 276 |
#pragma unroll
|
| 277 |
-
for (int k = 0; k <
|
| 278 |
#pragma unroll
|
| 279 |
for (int t = 0; t < ntiles/2; ++t) {
|
| 280 |
#pragma unroll
|
|
@@ -294,9 +456,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 294 |
}
|
| 295 |
}
|
| 296 |
|
| 297 |
-
static_assert(
|
| 298 |
#pragma unroll
|
| 299 |
-
for (int k = 0; k <
|
| 300 |
#pragma unroll
|
| 301 |
for (int t = 0; t < ntiles/2; ++t) {
|
| 302 |
#pragma unroll
|
|
@@ -325,7 +487,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 325 |
if (ntiles == 1) {
|
| 326 |
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
| 327 |
#pragma unroll
|
| 328 |
-
for (int i = 0; i <
|
| 329 |
#pragma unroll
|
| 330 |
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
|
| 331 |
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
|
@@ -336,7 +498,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 336 |
for (int col = 0; col < cols_per_thread; ++col) {
|
| 337 |
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
| 338 |
#pragma unroll
|
| 339 |
-
for (int i = 0; i <
|
| 340 |
#pragma unroll
|
| 341 |
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
|
| 342 |
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
|
|
@@ -347,16 +509,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 347 |
}
|
| 348 |
|
| 349 |
// Convert KQ C tiles into B tiles for VKQ calculation:
|
| 350 |
-
tile_B B[
|
| 351 |
tile_B_16 * B_16 = (tile_B_16 *) B;
|
| 352 |
-
static_assert(
|
| 353 |
if (ntiles == 1) {
|
| 354 |
#pragma unroll
|
| 355 |
-
for (int k = 0; k <
|
| 356 |
B[k] = get_transposed(get_half2(KQ_C[k]));
|
| 357 |
}
|
| 358 |
} else {
|
| 359 |
-
for (int k = 0; k <
|
| 360 |
#pragma unroll
|
| 361 |
for (int t = 0; t < ntiles/2; ++t) {
|
| 362 |
B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
|
|
@@ -364,52 +526,67 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 364 |
}
|
| 365 |
}
|
| 366 |
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
if (
|
| 373 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
}
|
| 375 |
-
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV);
|
| 376 |
}
|
| 377 |
-
#else
|
| 378 |
-
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
|
| 379 |
-
__syncthreads();
|
| 380 |
-
#endif // CP_ASYNC_AVAILABLE
|
| 381 |
|
| 382 |
-
// Calculate VKQ tile:
|
| 383 |
#pragma unroll
|
| 384 |
-
for (int
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
#pragma unroll
|
| 387 |
-
|
| 388 |
-
|
| 389 |
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
#pragma unroll
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
|
|
|
| 399 |
}
|
| 400 |
}
|
| 401 |
}
|
| 402 |
-
}
|
| 403 |
-
|
| 404 |
-
#ifndef CP_ASYNC_AVAILABLE
|
| 405 |
-
__syncthreads(); // Only needed if tile_K == tile_V.
|
| 406 |
-
#endif // CP_ASYNC_AVAILABLE
|
| 407 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
#else
|
| 409 |
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
| 410 |
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
| 411 |
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
| 412 |
-
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(
|
| 413 |
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
| 414 |
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
| 415 |
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
|
|
@@ -419,7 +596,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 419 |
#endif // NEW_MMA_AVAILABLE
|
| 420 |
}
|
| 421 |
|
| 422 |
-
template<int
|
| 423 |
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
| 424 |
const float2 * const __restrict__ Q_f2,
|
| 425 |
const half2 * const __restrict__ K_h2,
|
|
@@ -434,7 +611,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 434 |
const int ne02,
|
| 435 |
const int stride_Q1,
|
| 436 |
const int stride_Q2,
|
| 437 |
-
const int
|
|
|
|
| 438 |
const int stride_mask,
|
| 439 |
const int jt,
|
| 440 |
const int kb0_start,
|
|
@@ -442,6 +620,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 442 |
#ifdef NEW_MMA_AVAILABLE
|
| 443 |
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
constexpr int ncols = ncols1 * ncols2;
|
| 446 |
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 447 |
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
|
@@ -449,22 +635,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 449 |
|
| 450 |
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
| 451 |
|
| 452 |
-
|
| 453 |
-
|
|
|
|
| 454 |
|
| 455 |
-
constexpr int
|
| 456 |
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
half2 *
|
| 461 |
-
#else
|
| 462 |
-
half2 * tile_V = tile_K;
|
| 463 |
-
#endif // CP_ASYNC_AVAILABLE
|
| 464 |
-
half2 * tile_mask = tile_V + KQ_per_iter*D2_padded;
|
| 465 |
|
| 466 |
-
tile_B Q_B[
|
| 467 |
-
tile_C_VKQ VKQ_C[
|
| 468 |
|
| 469 |
tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
|
| 470 |
tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
|
|
@@ -476,13 +659,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 476 |
KQ_max[col] = -FLT_MAX/2.0f;
|
| 477 |
}
|
| 478 |
|
| 479 |
-
//
|
|
|
|
| 480 |
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
| 481 |
const half2 scale_h2 = make_half2(scale, scale);
|
| 482 |
#pragma unroll
|
| 483 |
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 484 |
-
const int k0_start = stride_k == WARP_SIZE ? 0 :
|
| 485 |
-
const int k0_stop =
|
| 486 |
const int stride_jc = WARP_SIZE / stride_k;
|
| 487 |
|
| 488 |
if (k0_start == k0_stop) {
|
|
@@ -506,14 +690,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 506 |
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 507 |
|
| 508 |
const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
|
| 509 |
-
|
| 510 |
}
|
| 511 |
} else {
|
| 512 |
#pragma unroll
|
| 513 |
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 514 |
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 515 |
|
| 516 |
-
|
| 517 |
}
|
| 518 |
}
|
| 519 |
}
|
|
@@ -521,18 +705,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 521 |
|
| 522 |
__syncthreads();
|
| 523 |
|
| 524 |
-
{
|
| 525 |
const int j0 = (threadIdx.y / np) * cols_per_warp;
|
| 526 |
|
| 527 |
#pragma unroll
|
| 528 |
-
for (int k0 = 0; k0 <
|
| 529 |
if (ntiles == 1) {
|
| 530 |
-
load_ldmatrix(Q_B[k0/tile_B::J],
|
| 531 |
} else {
|
| 532 |
#pragma unroll
|
| 533 |
for (int t = 0; t < ntiles/2; ++t) {
|
| 534 |
load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
|
| 535 |
-
|
| 536 |
}
|
| 537 |
}
|
| 538 |
}
|
|
@@ -540,35 +724,37 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 540 |
|
| 541 |
__syncthreads();
|
| 542 |
|
| 543 |
-
// Preload mask and K data for first iteration when using cp_async:
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
}
|
| 548 |
-
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV);
|
| 549 |
-
#endif // CP_ASYNC_AVAILABLE
|
| 550 |
|
| 551 |
// Iterate over ne11 == previous tokens:
|
| 552 |
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
| 553 |
constexpr bool last_iter = false;
|
| 554 |
-
flash_attn_ext_f16_iter<
|
| 555 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
| 556 |
-
ne01, ne02,
|
| 557 |
}
|
| 558 |
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
| 559 |
constexpr bool last_iter = true;
|
| 560 |
-
flash_attn_ext_f16_iter<
|
| 561 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
| 562 |
-
ne01, ne02,
|
| 563 |
}
|
| 564 |
|
| 565 |
-
// With
|
| 566 |
// there can be a race condition on shared memory access for combining/writing back results.
|
| 567 |
-
|
| 568 |
-
if (nwarps*cols_per_warp > KQ_per_iter) {
|
| 569 |
__syncthreads();
|
| 570 |
}
|
| 571 |
-
#endif // CP_ASYNC_AVAILABLE
|
| 572 |
|
| 573 |
// Finally, sum up partial KQ rowsums.
|
| 574 |
// The partial sums are spread across 8/4 threads each, does not need full reduce.
|
|
@@ -584,38 +770,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 584 |
}
|
| 585 |
}
|
| 586 |
|
| 587 |
-
//
|
| 588 |
-
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
| 589 |
-
//
|
| 590 |
-
if (ntiles == 1) {
|
| 591 |
-
const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
|
| 592 |
-
#pragma unroll
|
| 593 |
-
for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
|
| 594 |
-
const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
|
| 595 |
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
tile_K[jc_cwd*D2_padded + k] = B.x[l];
|
| 601 |
-
}
|
| 602 |
-
}
|
| 603 |
-
} else {
|
| 604 |
-
#pragma unroll
|
| 605 |
-
for (int t = 0; t < ntiles/2; ++t) {
|
| 606 |
-
const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
|
| 607 |
-
#pragma unroll
|
| 608 |
-
for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) {
|
| 609 |
-
#pragma unroll
|
| 610 |
-
for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
|
| 611 |
-
const int j = j0 + tile_C_VKQ_16::get_i(l);
|
| 612 |
-
const int k = k0 + tile_C_VKQ_16::get_j(l);
|
| 613 |
-
|
| 614 |
-
tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
|
| 615 |
-
}
|
| 616 |
-
}
|
| 617 |
-
}
|
| 618 |
-
}
|
| 619 |
|
| 620 |
if constexpr (ntiles == 1) {
|
| 621 |
const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
|
|
@@ -624,7 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 624 |
|
| 625 |
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
|
| 626 |
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
|
| 627 |
-
((float2 *)
|
| 628 |
}
|
| 629 |
|
| 630 |
__syncthreads();
|
|
@@ -649,7 +810,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 649 |
|
| 650 |
if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
|
| 651 |
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
|
| 652 |
-
((float2 *)
|
| 653 |
}
|
| 654 |
|
| 655 |
__syncthreads();
|
|
@@ -676,11 +837,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 676 |
constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
|
| 677 |
|
| 678 |
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
| 679 |
-
float2 * const meta_ptr = ((float2 *)
|
| 680 |
float2 meta[nmeta];
|
| 681 |
#pragma unroll
|
| 682 |
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
| 683 |
-
meta[imeta] = meta_ptr[imeta * WARP_SIZE *
|
| 684 |
}
|
| 685 |
|
| 686 |
float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
|
|
@@ -690,10 +851,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 690 |
}
|
| 691 |
#pragma unroll
|
| 692 |
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
| 693 |
-
if (offset
|
| 694 |
-
|
| 695 |
}
|
| 696 |
-
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
|
| 697 |
}
|
| 698 |
|
| 699 |
float KQ_cms[nmeta]; // KQ combine max scale per warp.
|
|
@@ -709,10 +869,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 709 |
}
|
| 710 |
#pragma unroll
|
| 711 |
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
| 712 |
-
if (offset
|
| 713 |
-
|
| 714 |
}
|
| 715 |
-
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
|
| 716 |
}
|
| 717 |
|
| 718 |
// Write back combined meta data:
|
|
@@ -720,7 +879,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 720 |
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
| 721 |
if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
|
| 722 |
// Combined KQ max scale + rowsum.
|
| 723 |
-
meta_ptr[imeta * WARP_SIZE *
|
| 724 |
}
|
| 725 |
}
|
| 726 |
|
|
@@ -736,88 +895,118 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 736 |
}
|
| 737 |
}
|
| 738 |
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
|
| 747 |
|
| 748 |
#pragma unroll
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
const int k0_stop = D/2 - (D/2) % (1*stride_k);
|
| 752 |
-
const int stride_jc = WARP_SIZE / stride_k;
|
| 753 |
|
| 754 |
-
|
| 755 |
-
|
| 756 |
}
|
| 757 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
#pragma unroll
|
| 759 |
-
|
| 760 |
-
|
|
|
|
| 761 |
|
| 762 |
-
|
| 763 |
-
|
| 764 |
}
|
|
|
|
|
|
|
| 765 |
|
| 766 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 767 |
|
| 768 |
-
|
| 769 |
-
|
|
|
|
|
|
|
|
|
|
| 770 |
|
| 771 |
-
if (
|
| 772 |
continue;
|
| 773 |
}
|
| 774 |
|
| 775 |
-
const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2;
|
| 776 |
#pragma unroll
|
| 777 |
-
for (int
|
| 778 |
-
const int
|
| 779 |
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
for (int ip = 0; ip < np; ++ip) {
|
| 783 |
-
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0];
|
| 784 |
-
const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]);
|
| 785 |
-
dstk_val.x += dstk_val_add.x*KQ_crs;
|
| 786 |
-
dstk_val.y += dstk_val_add.y*KQ_crs;
|
| 787 |
}
|
| 788 |
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
|
|
|
|
|
|
|
|
|
| 793 |
}
|
| 794 |
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 799 |
}
|
| 800 |
}
|
| 801 |
}
|
| 802 |
}
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
__syncthreads();
|
| 807 |
}
|
| 808 |
#else
|
| 809 |
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
| 810 |
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
| 811 |
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
| 812 |
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
|
| 813 |
-
GGML_UNUSED(stride_Q2); GGML_UNUSED(
|
| 814 |
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
|
| 815 |
NO_DEVICE_CODE;
|
| 816 |
#endif // NEW_MMA_AVAILABLE
|
| 817 |
}
|
| 818 |
|
| 819 |
-
template<int
|
| 820 |
-
__launch_bounds__(nwarps*WARP_SIZE,
|
| 821 |
static __global__ void flash_attn_ext_f16(
|
| 822 |
const char * __restrict__ Q,
|
| 823 |
const char * __restrict__ K,
|
|
@@ -857,24 +1046,27 @@ static __global__ void flash_attn_ext_f16(
|
|
| 857 |
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
| 858 |
|
| 859 |
// Skip unused kernel variants for faster compilation:
|
| 860 |
-
if (use_logit_softcap && !(
|
| 861 |
NO_DEVICE_CODE;
|
| 862 |
return;
|
| 863 |
}
|
| 864 |
|
| 865 |
-
|
|
|
|
|
|
|
| 866 |
|
| 867 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 868 |
|
| 869 |
const int stride_Q1 = nb01 / sizeof(float2);
|
| 870 |
const int stride_Q2 = nb02 / sizeof(float2);
|
| 871 |
-
const int
|
|
|
|
| 872 |
const int stride_mask = nb31 / sizeof(half2);
|
| 873 |
|
| 874 |
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
| 875 |
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
| 876 |
|
| 877 |
-
constexpr int kb_niter = FATTN_KQ_STRIDE /
|
| 878 |
|
| 879 |
// kbc == k block continuous, current index in continuous ijk space.
|
| 880 |
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
@@ -893,9 +1085,9 @@ static __global__ void flash_attn_ext_f16(
|
|
| 893 |
|
| 894 |
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
| 895 |
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
| 896 |
-
const half2 * V_h2 = (const half2 *) (V +
|
| 897 |
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
| 898 |
-
float2 * dstk = ((float2 *) dst) + channel*(ncols2 *
|
| 899 |
|
| 900 |
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
| 901 |
|
|
@@ -905,14 +1097,14 @@ static __global__ void flash_attn_ext_f16(
|
|
| 905 |
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
| 906 |
if (kb0_start == 0) {
|
| 907 |
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
| 908 |
-
flash_attn_ext_f16_process_tile<
|
| 909 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 910 |
-
ne01, ne02, stride_Q1, stride_Q2,
|
| 911 |
} else {
|
| 912 |
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
| 913 |
-
flash_attn_ext_f16_process_tile<
|
| 914 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 915 |
-
ne01, ne02, stride_Q1, stride_Q2,
|
| 916 |
}
|
| 917 |
|
| 918 |
kbc += iter_k;
|
|
@@ -931,9 +1123,9 @@ static __global__ void flash_attn_ext_f16(
|
|
| 931 |
|
| 932 |
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
| 933 |
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
| 934 |
-
const half2 * V_h2 = (const half2 *) (V +
|
| 935 |
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
| 936 |
-
float2 * dstk = ((float2 *) dst) + channel*(ncols2 *
|
| 937 |
|
| 938 |
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
| 939 |
|
|
@@ -942,9 +1134,9 @@ static __global__ void flash_attn_ext_f16(
|
|
| 942 |
|
| 943 |
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
| 944 |
constexpr bool needs_fixup = false;
|
| 945 |
-
flash_attn_ext_f16_process_tile<
|
| 946 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 947 |
-
ne01, ne02, stride_Q1, stride_Q2,
|
| 948 |
#else
|
| 949 |
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
| 950 |
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
|
@@ -960,28 +1152,42 @@ static __global__ void flash_attn_ext_f16(
|
|
| 960 |
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
| 961 |
}
|
| 962 |
|
| 963 |
-
template <int
|
| 964 |
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 965 |
constexpr int ncols = ncols1 * ncols2;
|
| 966 |
-
constexpr int
|
| 967 |
-
constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4;
|
| 968 |
-
constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4);
|
| 969 |
constexpr int cols_per_warp = ntiles * tile_B::I;
|
|
|
|
|
|
|
|
|
|
| 970 |
|
| 971 |
-
static_assert(
|
|
|
|
| 972 |
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
| 973 |
|
| 974 |
-
const
|
| 975 |
-
const
|
| 976 |
-
const
|
|
|
|
|
|
|
| 977 |
|
| 978 |
-
const
|
| 979 |
|
| 980 |
-
const size_t
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
|
| 984 |
-
const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine);
|
| 985 |
|
| 986 |
float logit_softcap;
|
| 987 |
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
|
@@ -989,59 +1195,73 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
| 989 |
fattn_kernel_t fattn_kernel;
|
| 990 |
if (logit_softcap == 0.0f) {
|
| 991 |
constexpr bool use_logit_softcap = false;
|
| 992 |
-
fattn_kernel = flash_attn_ext_f16<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 993 |
} else {
|
| 994 |
constexpr bool use_logit_softcap = true;
|
| 995 |
-
fattn_kernel = flash_attn_ext_f16<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 996 |
}
|
| 997 |
|
| 998 |
-
launch_fattn<
|
| 999 |
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
|
| 1000 |
}
|
| 1001 |
|
| 1002 |
|
| 1003 |
-
#define DECL_FATTN_MMA_F16_CASE(
|
| 1004 |
-
template void ggml_cuda_flash_attn_ext_mma_f16_case
|
| 1005 |
-
<
|
| 1006 |
-
|
| 1007 |
-
#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1008 |
-
extern DECL_FATTN_MMA_F16_CASE(
|
| 1009 |
-
extern DECL_FATTN_MMA_F16_CASE(
|
| 1010 |
-
extern DECL_FATTN_MMA_F16_CASE(
|
| 1011 |
-
extern DECL_FATTN_MMA_F16_CASE(
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1015 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1016 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1017 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1018 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1022 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1023 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1024 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1025 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1029 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1030 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1031 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1032 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1033 |
-
|
| 1034 |
-
|
| 1035 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1036 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1037 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1038 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1039 |
-
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(
|
| 1040 |
-
|
| 1041 |
-
|
| 1042 |
-
//
|
| 1043 |
-
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
|
| 1047 |
-
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
|
|
|
|
| 13 |
typedef tile<16, 4, half2> tile_C_VKQ;
|
| 14 |
typedef tile<16, 8, half2> tile_C_VKQ_16;
|
| 15 |
|
| 16 |
+
// Config options for specific head sizes.
|
| 17 |
+
// Should not affect results, only speed/register pressure/shared memory use.
|
| 18 |
+
//
|
| 19 |
+
// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
|
| 20 |
+
// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory).
|
| 21 |
+
// Q_in_reg: whether the Q values should be kept permanently in registers.
|
| 22 |
+
// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading.
|
| 23 |
+
// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel.
|
| 24 |
+
// nbatch_V2: number of V half2 values in direction of DV to load in parallel.
|
| 25 |
+
// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
|
| 26 |
+
|
| 27 |
+
template <int DKQ, int DV>
|
| 28 |
+
struct fattn_mma_f16_config;
|
| 29 |
+
|
| 30 |
+
template <>
|
| 31 |
+
struct fattn_mma_f16_config< 64, 64> {
|
| 32 |
+
static constexpr int nbatch_fa = 64;
|
| 33 |
+
static constexpr int nwarps_max = 4;
|
| 34 |
+
static constexpr bool Q_in_reg = true;
|
| 35 |
+
static constexpr int nstages_target = 2;
|
| 36 |
+
static constexpr int nbatch_K2 = 32;
|
| 37 |
+
static constexpr int nbatch_V2 = 32;
|
| 38 |
+
static constexpr int nbatch_combine = 32;
|
| 39 |
+
};
|
| 40 |
+
|
| 41 |
+
template <>
|
| 42 |
+
struct fattn_mma_f16_config< 80, 80> {
|
| 43 |
+
static constexpr int nbatch_fa = 64;
|
| 44 |
+
static constexpr int nwarps_max = 4;
|
| 45 |
+
static constexpr bool Q_in_reg = true;
|
| 46 |
+
static constexpr int nstages_target = 2;
|
| 47 |
+
static constexpr int nbatch_K2 = 40;
|
| 48 |
+
static constexpr int nbatch_V2 = 40;
|
| 49 |
+
static constexpr int nbatch_combine = 40;
|
| 50 |
+
};
|
| 51 |
+
|
| 52 |
+
template <>
|
| 53 |
+
struct fattn_mma_f16_config< 96, 96> {
|
| 54 |
+
static constexpr int nbatch_fa = 64;
|
| 55 |
+
static constexpr int nwarps_max = 4;
|
| 56 |
+
static constexpr bool Q_in_reg = true;
|
| 57 |
+
static constexpr int nstages_target = 2;
|
| 58 |
+
static constexpr int nbatch_K2 = 48;
|
| 59 |
+
static constexpr int nbatch_V2 = 48;
|
| 60 |
+
static constexpr int nbatch_combine = 48;
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
template <>
|
| 64 |
+
struct fattn_mma_f16_config<112, 112> {
|
| 65 |
+
static constexpr int nbatch_fa = 64;
|
| 66 |
+
static constexpr int nwarps_max = 4;
|
| 67 |
+
static constexpr bool Q_in_reg = true;
|
| 68 |
+
static constexpr int nstages_target = 2;
|
| 69 |
+
static constexpr int nbatch_K2 = 56;
|
| 70 |
+
static constexpr int nbatch_V2 = 56;
|
| 71 |
+
static constexpr int nbatch_combine = 56;
|
| 72 |
+
};
|
| 73 |
+
|
| 74 |
+
template <>
|
| 75 |
+
struct fattn_mma_f16_config<128, 128> {
|
| 76 |
+
static constexpr int nbatch_fa = 64;
|
| 77 |
+
static constexpr int nwarps_max = 4;
|
| 78 |
+
static constexpr bool Q_in_reg = true;
|
| 79 |
+
static constexpr int nstages_target = 2;
|
| 80 |
+
static constexpr int nbatch_K2 = 64;
|
| 81 |
+
static constexpr int nbatch_V2 = 64;
|
| 82 |
+
static constexpr int nbatch_combine = 64;
|
| 83 |
+
};
|
| 84 |
+
|
| 85 |
+
template <>
|
| 86 |
+
struct fattn_mma_f16_config<256, 256> {
|
| 87 |
+
static constexpr int nbatch_fa = 32;
|
| 88 |
+
static constexpr int nwarps_max = 4;
|
| 89 |
+
static constexpr bool Q_in_reg = true;
|
| 90 |
+
static constexpr int nstages_target = 2;
|
| 91 |
+
static constexpr int nbatch_K2 = 128;
|
| 92 |
+
static constexpr int nbatch_V2 = 128;
|
| 93 |
+
static constexpr int nbatch_combine = 128;
|
| 94 |
+
};
|
| 95 |
+
|
| 96 |
+
template <>
|
| 97 |
+
struct fattn_mma_f16_config<576, 512> {
|
| 98 |
+
static constexpr int nbatch_fa = 32;
|
| 99 |
+
static constexpr int nwarps_max = 8;
|
| 100 |
+
static constexpr bool Q_in_reg = false;
|
| 101 |
+
static constexpr int nstages_target = 1;
|
| 102 |
+
static constexpr int nbatch_K2 = 160;
|
| 103 |
+
static constexpr int nbatch_V2 = 128;
|
| 104 |
+
static constexpr int nbatch_combine = 128;
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
// ------------------------------------------------------------------------------------------------------------------
|
| 108 |
+
|
| 109 |
+
template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
|
| 110 |
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
| 111 |
+
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
|
|
|
|
| 112 |
|
| 113 |
+
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
| 114 |
+
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
|
| 115 |
+
|
| 116 |
+
if (use_cp_async) {
|
| 117 |
+
constexpr int preload = 64;
|
| 118 |
+
constexpr int h2_per_chunk = 16/sizeof(half2);
|
| 119 |
+
const int chunks_per_row = D2 / h2_per_chunk;
|
| 120 |
|
| 121 |
+
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
| 122 |
+
|
| 123 |
+
auto load = [&] __device__ (const int n) {
|
| 124 |
+
const int stride_k = WARP_SIZE >> n;
|
| 125 |
+
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
| 126 |
+
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
| 127 |
+
const int stride_i = WARP_SIZE / stride_k;
|
| 128 |
+
|
| 129 |
+
if (k0_start == k0_stop) {
|
| 130 |
+
return;
|
| 131 |
+
}
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
#pragma unroll
|
| 134 |
+
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
| 135 |
+
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
|
|
|
| 136 |
|
| 137 |
+
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
| 138 |
+
break;
|
| 139 |
+
}
|
|
|
|
|
|
|
|
|
|
| 140 |
|
|
|
|
|
|
|
|
|
|
| 141 |
#pragma unroll
|
| 142 |
+
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 143 |
+
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
|
|
|
|
|
|
| 144 |
|
| 145 |
+
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
};
|
| 149 |
+
ggml_cuda_unroll<5>{}(load);
|
| 150 |
+
} else {
|
| 151 |
+
static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
|
| 152 |
+
auto load = [&] __device__ (const int n) {
|
| 153 |
+
const int stride_k = WARP_SIZE >> n;
|
| 154 |
+
const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
|
| 155 |
+
const int k0_stop = D2 - D2 % (1*stride_k);
|
| 156 |
+
const int stride_i = WARP_SIZE / stride_k;
|
| 157 |
+
|
| 158 |
+
if (k0_start == k0_stop) {
|
| 159 |
+
return;
|
| 160 |
+
}
|
| 161 |
|
| 162 |
#pragma unroll
|
| 163 |
+
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
| 164 |
+
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
| 165 |
+
|
| 166 |
+
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
| 167 |
+
break;
|
| 168 |
+
}
|
| 169 |
|
| 170 |
#pragma unroll
|
| 171 |
+
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 172 |
+
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 173 |
|
| 174 |
+
tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
|
| 175 |
+
}
|
| 176 |
}
|
| 177 |
+
};
|
| 178 |
+
ggml_cuda_unroll<3>{}(load);
|
| 179 |
}
|
| 180 |
}
|
| 181 |
|
| 182 |
+
template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
|
| 183 |
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
| 184 |
const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
|
| 185 |
+
static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
|
| 186 |
+
|
| 187 |
+
if (use_cp_async) {
|
| 188 |
+
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
|
| 189 |
+
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
|
| 190 |
+
constexpr int stride_j = nwarps * cols_per_warp;
|
| 191 |
|
| 192 |
+
const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
|
| 193 |
|
| 194 |
#pragma unroll
|
| 195 |
+
for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
|
| 196 |
+
const int j = j0 + threadIdx.y*cols_per_warp +
|
| 197 |
+
(nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp));
|
| 198 |
|
| 199 |
+
if (j0 + stride_j > ncols1 && j >= ncols1) {
|
| 200 |
+
break;
|
| 201 |
+
}
|
| 202 |
|
| 203 |
+
const int i = 4 * (threadIdx.x % (nbatch_fa/8));
|
| 204 |
|
| 205 |
+
cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
|
| 206 |
+
}
|
| 207 |
+
return;
|
| 208 |
}
|
| 209 |
+
|
| 210 |
+
constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
|
| 211 |
constexpr int stride_j = nwarps * cols_per_warp;
|
| 212 |
#pragma unroll
|
| 213 |
for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
|
| 214 |
+
const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp));
|
| 215 |
|
| 216 |
if (j0 + stride_j > ncols1 && j >= ncols1) {
|
| 217 |
break;
|
| 218 |
}
|
| 219 |
|
| 220 |
+
const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp);
|
| 221 |
|
| 222 |
+
tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i];
|
| 223 |
}
|
|
|
|
| 224 |
}
|
| 225 |
|
| 226 |
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
|
| 227 |
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
| 228 |
const float2 * const __restrict__ Q_f2,
|
| 229 |
const half2 * const __restrict__ K_h2,
|
|
|
|
| 236 |
const float logit_softcap,
|
| 237 |
const int ne01,
|
| 238 |
const int ne02,
|
| 239 |
+
const int stride_K,
|
| 240 |
+
const int stride_V,
|
| 241 |
const int stride_mask,
|
| 242 |
const int jt,
|
| 243 |
+
half2 * const __restrict__ tile_Q,
|
| 244 |
half2 * const __restrict__ tile_K,
|
| 245 |
half2 * const __restrict__ tile_V,
|
| 246 |
half2 * const __restrict__ tile_mask,
|
|
|
|
| 250 |
float * const __restrict__ KQ_rowsum,
|
| 251 |
const int kb0) {
|
| 252 |
#ifdef NEW_MMA_AVAILABLE
|
| 253 |
+
typedef fattn_mma_f16_config<DKQ, DV> c;
|
| 254 |
+
|
| 255 |
+
#ifdef CP_ASYNC_AVAILABLE
|
| 256 |
+
constexpr int nstages = c::nstages_target;
|
| 257 |
+
#else
|
| 258 |
+
constexpr int nstages = 0;
|
| 259 |
+
#endif // CP_ASYNC_AVAILABLE
|
| 260 |
+
|
| 261 |
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 262 |
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
| 263 |
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
|
|
|
| 264 |
|
| 265 |
+
constexpr int stride_tile_Q = DKQ/2 + 4;
|
| 266 |
+
constexpr int stride_tile_K = c::nbatch_K2 + 4;
|
| 267 |
+
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
| 268 |
+
|
| 269 |
+
const int k_VKQ_0 = kb0 * c::nbatch_fa;
|
| 270 |
+
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
|
| 271 |
|
| 272 |
// Use wide variants of tiles if ntiles >= 2.
|
| 273 |
tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
|
| 274 |
tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
|
| 275 |
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
| 276 |
|
| 277 |
+
if constexpr (nstages > 1) {
|
| 278 |
+
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
| 279 |
+
constexpr bool use_cp_async = true;
|
| 280 |
+
cp_async_wait_all();
|
| 281 |
+
__syncthreads();
|
| 282 |
+
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
| 283 |
+
(V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
|
| 284 |
+
} else {
|
| 285 |
+
constexpr bool use_cp_async = nstages == 1;
|
| 286 |
+
if (ncols2 > 1 || mask_h2) {
|
| 287 |
+
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
|
| 288 |
+
}
|
| 289 |
}
|
|
|
|
|
|
|
|
|
|
| 290 |
|
|
|
|
| 291 |
#pragma unroll
|
| 292 |
+
for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
|
| 293 |
+
const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
|
| 294 |
+
const int k0_diff = k0_stop - k0_start;
|
| 295 |
+
|
| 296 |
+
if (nstages <= 1) {
|
| 297 |
+
constexpr bool use_cp_async = nstages == 1;
|
| 298 |
+
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
| 299 |
+
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
|
| 300 |
+
if (use_cp_async) {
|
| 301 |
+
cp_async_wait_all();
|
| 302 |
+
}
|
| 303 |
+
__syncthreads();
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
// Calculate tile of KQ:
|
| 307 |
+
if constexpr (c::Q_in_reg) {
|
| 308 |
#pragma unroll
|
| 309 |
+
for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
|
| 310 |
+
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
#pragma unroll
|
| 312 |
+
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
|
| 313 |
+
tile_A K_A;
|
| 314 |
+
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
| 315 |
+
if (ntiles == 1) {
|
| 316 |
+
mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
|
| 317 |
+
} else {
|
| 318 |
+
#pragma unroll
|
| 319 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 320 |
+
// Wide version of KQ_C is column-major => swap A and B.
|
| 321 |
+
mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
|
| 322 |
+
}
|
| 323 |
+
}
|
| 324 |
+
}
|
| 325 |
+
}
|
| 326 |
+
} else {
|
| 327 |
+
static_assert(ntiles == 2, "ntiles != 2 not implemented");
|
| 328 |
+
#pragma unroll
|
| 329 |
+
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
|
| 330 |
+
load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
| 331 |
+
|
| 332 |
+
#pragma unroll
|
| 333 |
+
for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
|
| 334 |
+
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
|
| 335 |
+
|
| 336 |
+
tile_A K_A;
|
| 337 |
+
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
| 338 |
+
|
| 339 |
// Wide version of KQ_C is column-major => swap A and B.
|
| 340 |
+
mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A);
|
| 341 |
}
|
| 342 |
}
|
| 343 |
}
|
|
|
|
| 344 |
|
| 345 |
+
if (nstages <= 1) {
|
| 346 |
+
__syncthreads(); // Only needed if tile_K == tile_V.
|
| 347 |
+
}
|
| 348 |
+
}
|
| 349 |
|
| 350 |
if (use_logit_softcap) {
|
| 351 |
+
static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
|
| 352 |
#pragma unroll
|
| 353 |
+
for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
|
| 354 |
#pragma unroll
|
| 355 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 356 |
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
|
|
|
|
| 368 |
if (ntiles == 1) {
|
| 369 |
if (ncols2 > 1 || mask_h2) {
|
| 370 |
#pragma unroll
|
| 371 |
+
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
|
| 372 |
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
|
| 373 |
#pragma unroll
|
| 374 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
|
|
|
| 376 |
const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
|
| 377 |
|
| 378 |
KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
|
| 379 |
+
__half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
|
| 380 |
}
|
| 381 |
}
|
| 382 |
}
|
| 383 |
|
| 384 |
// Calculate softmax for each KQ column using the current max. value.
|
| 385 |
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
| 386 |
+
static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
|
| 387 |
#pragma unroll
|
| 388 |
+
for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
|
| 389 |
#pragma unroll
|
| 390 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 391 |
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
|
|
|
|
| 401 |
}
|
| 402 |
}
|
| 403 |
|
| 404 |
+
static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
|
|
|
|
| 405 |
#pragma unroll
|
| 406 |
+
for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
|
| 407 |
#pragma unroll
|
| 408 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 409 |
KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
|
|
|
|
| 414 |
} else { // ntiles > 1
|
| 415 |
if (ncols2 > 1 || mask_h2) {
|
| 416 |
#pragma unroll
|
| 417 |
+
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) {
|
| 418 |
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
|
| 419 |
#pragma unroll
|
| 420 |
for (int t = 0; t < ntiles/2; ++t) {
|
|
|
|
| 423 |
const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
|
| 424 |
const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
|
| 425 |
|
| 426 |
+
const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]);
|
| 427 |
const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
|
| 428 |
KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
|
| 429 |
KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
|
|
|
|
| 434 |
|
| 435 |
// Calculate softmax for each KQ column using the current max. value.
|
| 436 |
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
| 437 |
+
static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
|
| 438 |
#pragma unroll
|
| 439 |
+
for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
|
| 440 |
#pragma unroll
|
| 441 |
for (int t = 0; t < ntiles/2; ++t) {
|
| 442 |
#pragma unroll
|
|
|
|
| 456 |
}
|
| 457 |
}
|
| 458 |
|
| 459 |
+
static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size");
|
| 460 |
#pragma unroll
|
| 461 |
+
for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
|
| 462 |
#pragma unroll
|
| 463 |
for (int t = 0; t < ntiles/2; ++t) {
|
| 464 |
#pragma unroll
|
|
|
|
| 487 |
if (ntiles == 1) {
|
| 488 |
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
| 489 |
#pragma unroll
|
| 490 |
+
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
|
| 491 |
#pragma unroll
|
| 492 |
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
|
| 493 |
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
|
|
|
| 498 |
for (int col = 0; col < cols_per_thread; ++col) {
|
| 499 |
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
| 500 |
#pragma unroll
|
| 501 |
+
for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
|
| 502 |
#pragma unroll
|
| 503 |
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
|
| 504 |
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
|
|
|
|
| 509 |
}
|
| 510 |
|
| 511 |
// Convert KQ C tiles into B tiles for VKQ calculation:
|
| 512 |
+
tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
|
| 513 |
tile_B_16 * B_16 = (tile_B_16 *) B;
|
| 514 |
+
static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
|
| 515 |
if (ntiles == 1) {
|
| 516 |
#pragma unroll
|
| 517 |
+
for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
|
| 518 |
B[k] = get_transposed(get_half2(KQ_C[k]));
|
| 519 |
}
|
| 520 |
} else {
|
| 521 |
+
for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) {
|
| 522 |
#pragma unroll
|
| 523 |
for (int t = 0; t < ntiles/2; ++t) {
|
| 524 |
B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
|
|
|
|
| 526 |
}
|
| 527 |
}
|
| 528 |
|
| 529 |
+
if (nstages > 1) {
|
| 530 |
+
// Preload K tile for next iteration:
|
| 531 |
+
constexpr bool use_cp_async = true;
|
| 532 |
+
cp_async_wait_all();
|
| 533 |
+
__syncthreads();
|
| 534 |
+
if (!last_iter) {
|
| 535 |
+
if (ncols2 > 1 || mask_h2) {
|
| 536 |
+
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
|
| 537 |
+
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
| 538 |
+
}
|
| 539 |
+
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
| 540 |
+
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
|
| 541 |
}
|
|
|
|
| 542 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
|
|
|
|
| 544 |
#pragma unroll
|
| 545 |
+
for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
|
| 546 |
+
const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
|
| 547 |
+
const int i0_diff = i0_stop - i0_start;
|
| 548 |
+
|
| 549 |
+
if (nstages == 1) {
|
| 550 |
+
constexpr bool use_cp_async = nstages == 1;
|
| 551 |
+
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
| 552 |
+
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
| 553 |
+
if (use_cp_async) {
|
| 554 |
+
cp_async_wait_all();
|
| 555 |
+
}
|
| 556 |
+
__syncthreads();
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
// Calculate VKQ tile:
|
| 560 |
+
#pragma unroll
|
| 561 |
+
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) {
|
| 562 |
+
static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size");
|
| 563 |
#pragma unroll
|
| 564 |
+
for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) {
|
| 565 |
+
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
| 566 |
|
| 567 |
+
tile_A A;
|
| 568 |
+
load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
| 569 |
+
if (ntiles == 1) {
|
| 570 |
+
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
| 571 |
+
} else {
|
| 572 |
#pragma unroll
|
| 573 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 574 |
+
// Wide version of VKQ_C is column-major => swap A and B.
|
| 575 |
+
mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
|
| 576 |
+
}
|
| 577 |
}
|
| 578 |
}
|
| 579 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
|
| 581 |
+
if (nstages <= 1) {
|
| 582 |
+
__syncthreads(); // Only needed if tile_K == tile_V.
|
| 583 |
+
}
|
| 584 |
+
}
|
| 585 |
#else
|
| 586 |
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
| 587 |
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
| 588 |
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
| 589 |
+
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
|
| 590 |
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
| 591 |
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
| 592 |
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
|
|
|
|
| 596 |
#endif // NEW_MMA_AVAILABLE
|
| 597 |
}
|
| 598 |
|
| 599 |
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
|
| 600 |
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
| 601 |
const float2 * const __restrict__ Q_f2,
|
| 602 |
const half2 * const __restrict__ K_h2,
|
|
|
|
| 611 |
const int ne02,
|
| 612 |
const int stride_Q1,
|
| 613 |
const int stride_Q2,
|
| 614 |
+
const int stride_K,
|
| 615 |
+
const int stride_V,
|
| 616 |
const int stride_mask,
|
| 617 |
const int jt,
|
| 618 |
const int kb0_start,
|
|
|
|
| 620 |
#ifdef NEW_MMA_AVAILABLE
|
| 621 |
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 622 |
|
| 623 |
+
typedef fattn_mma_f16_config<DKQ, DV> c;
|
| 624 |
+
|
| 625 |
+
#ifdef CP_ASYNC_AVAILABLE
|
| 626 |
+
constexpr int nstages = c::nstages_target;
|
| 627 |
+
#else
|
| 628 |
+
constexpr int nstages = 0;
|
| 629 |
+
#endif // CP_ASYNC_AVAILABLE
|
| 630 |
+
|
| 631 |
constexpr int ncols = ncols1 * ncols2;
|
| 632 |
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 633 |
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
|
|
|
| 635 |
|
| 636 |
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
| 637 |
|
| 638 |
+
constexpr int stride_tile_Q = DKQ/2 + 4;
|
| 639 |
+
constexpr int stride_tile_K = c::nbatch_K2 + 4;
|
| 640 |
+
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
| 641 |
|
| 642 |
+
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
| 643 |
|
| 644 |
+
extern __shared__ half2 tile_Q[];
|
| 645 |
+
half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
|
| 646 |
+
half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K;
|
| 647 |
+
half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 648 |
|
| 649 |
+
tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles];
|
| 650 |
+
tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles];
|
| 651 |
|
| 652 |
tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
|
| 653 |
tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
|
|
|
|
| 659 |
KQ_max[col] = -FLT_MAX/2.0f;
|
| 660 |
}
|
| 661 |
|
| 662 |
+
// Load Q data into tile_Q, either temporarily or permanently.
|
| 663 |
+
// Q in registers is faster, but register pressure is the biggest bottleneck.
|
| 664 |
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
| 665 |
const half2 scale_h2 = make_half2(scale, scale);
|
| 666 |
#pragma unroll
|
| 667 |
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 668 |
+
const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
|
| 669 |
+
const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
|
| 670 |
const int stride_jc = WARP_SIZE / stride_k;
|
| 671 |
|
| 672 |
if (k0_start == k0_stop) {
|
|
|
|
| 690 |
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 691 |
|
| 692 |
const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
|
| 693 |
+
tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
|
| 694 |
}
|
| 695 |
} else {
|
| 696 |
#pragma unroll
|
| 697 |
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 698 |
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 699 |
|
| 700 |
+
tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
|
| 701 |
}
|
| 702 |
}
|
| 703 |
}
|
|
|
|
| 705 |
|
| 706 |
__syncthreads();
|
| 707 |
|
| 708 |
+
if (c::Q_in_reg) {
|
| 709 |
const int j0 = (threadIdx.y / np) * cols_per_warp;
|
| 710 |
|
| 711 |
#pragma unroll
|
| 712 |
+
for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
|
| 713 |
if (ntiles == 1) {
|
| 714 |
+
load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
|
| 715 |
} else {
|
| 716 |
#pragma unroll
|
| 717 |
for (int t = 0; t < ntiles/2; ++t) {
|
| 718 |
load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
|
| 719 |
+
tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
|
| 720 |
}
|
| 721 |
}
|
| 722 |
}
|
|
|
|
| 724 |
|
| 725 |
__syncthreads();
|
| 726 |
|
| 727 |
+
// Preload mask and K data for first iteration when using cp_async with multiple stages:
|
| 728 |
+
if constexpr (nstages > 1) {
|
| 729 |
+
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
| 730 |
+
constexpr bool use_cp_async = true;
|
| 731 |
+
if (ncols2 > 1 || mask_h2) {
|
| 732 |
+
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
|
| 733 |
+
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
| 734 |
+
}
|
| 735 |
+
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
| 736 |
+
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
|
| 737 |
}
|
|
|
|
|
|
|
| 738 |
|
| 739 |
// Iterate over ne11 == previous tokens:
|
| 740 |
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
| 741 |
constexpr bool last_iter = false;
|
| 742 |
+
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
| 743 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
| 744 |
+
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
| 745 |
}
|
| 746 |
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
| 747 |
constexpr bool last_iter = true;
|
| 748 |
+
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
| 749 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
| 750 |
+
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
| 751 |
}
|
| 752 |
|
| 753 |
+
// With multi-stage loading there is no __syncthreads at the end of the iter,
|
| 754 |
// there can be a race condition on shared memory access for combining/writing back results.
|
| 755 |
+
if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
|
|
|
|
| 756 |
__syncthreads();
|
| 757 |
}
|
|
|
|
| 758 |
|
| 759 |
// Finally, sum up partial KQ rowsums.
|
| 760 |
// The partial sums are spread across 8/4 threads each, does not need full reduce.
|
|
|
|
| 770 |
}
|
| 771 |
}
|
| 772 |
|
| 773 |
+
// Combine VKQ accumulator values if np > 1.
|
| 774 |
+
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
| 775 |
+
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
|
| 777 |
+
constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
|
| 778 |
+
constexpr int tile_stride = nbatch_combine + 4;
|
| 779 |
+
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 780 |
|
| 781 |
if constexpr (ntiles == 1) {
|
| 782 |
const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
|
|
|
|
| 785 |
|
| 786 |
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
|
| 787 |
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
|
| 788 |
+
((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
|
| 789 |
}
|
| 790 |
|
| 791 |
__syncthreads();
|
|
|
|
| 810 |
|
| 811 |
if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
|
| 812 |
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
|
| 813 |
+
((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
|
| 814 |
}
|
| 815 |
|
| 816 |
__syncthreads();
|
|
|
|
| 837 |
constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
|
| 838 |
|
| 839 |
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
| 840 |
+
float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
|
| 841 |
float2 meta[nmeta];
|
| 842 |
#pragma unroll
|
| 843 |
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
| 844 |
+
meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
|
| 845 |
}
|
| 846 |
|
| 847 |
float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
|
|
|
|
| 851 |
}
|
| 852 |
#pragma unroll
|
| 853 |
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
| 854 |
+
if (offset < WARP_SIZE) {
|
| 855 |
+
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
|
| 856 |
}
|
|
|
|
| 857 |
}
|
| 858 |
|
| 859 |
float KQ_cms[nmeta]; // KQ combine max scale per warp.
|
|
|
|
| 869 |
}
|
| 870 |
#pragma unroll
|
| 871 |
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
| 872 |
+
if (offset < WARP_SIZE) {
|
| 873 |
+
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
|
| 874 |
}
|
|
|
|
| 875 |
}
|
| 876 |
|
| 877 |
// Write back combined meta data:
|
|
|
|
| 879 |
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
| 880 |
if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
|
| 881 |
// Combined KQ max scale + rowsum.
|
| 882 |
+
meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
|
| 883 |
}
|
| 884 |
}
|
| 885 |
|
|
|
|
| 895 |
}
|
| 896 |
}
|
| 897 |
|
| 898 |
+
#pragma unroll
|
| 899 |
+
for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
|
| 900 |
+
if (ntiles == 1) {
|
| 901 |
+
const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
|
| 902 |
+
#pragma unroll
|
| 903 |
+
for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
|
| 904 |
+
const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
|
|
|
|
| 905 |
|
| 906 |
#pragma unroll
|
| 907 |
+
for (int l = 0; l < tile_B::ne; ++l) {
|
| 908 |
+
const int k = k0 + tile_B::get_j(l);
|
|
|
|
|
|
|
| 909 |
|
| 910 |
+
tile_Q[jc_cwd*tile_stride + k] = B.x[l];
|
| 911 |
+
}
|
| 912 |
}
|
| 913 |
+
} else {
|
| 914 |
+
#pragma unroll
|
| 915 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 916 |
+
const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
|
| 917 |
+
#pragma unroll
|
| 918 |
+
for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) {
|
| 919 |
#pragma unroll
|
| 920 |
+
for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
|
| 921 |
+
const int j = j0 + tile_C_VKQ_16::get_i(l);
|
| 922 |
+
const int k = k0 + tile_C_VKQ_16::get_j(l);
|
| 923 |
|
| 924 |
+
tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
|
| 925 |
+
}
|
| 926 |
}
|
| 927 |
+
}
|
| 928 |
+
}
|
| 929 |
|
| 930 |
+
__syncthreads();
|
| 931 |
+
|
| 932 |
+
if (np == 1 || threadIdx.y % np == 0) {
|
| 933 |
+
// The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
|
| 934 |
+
// The values after that are for the partial results of the individual blocks.
|
| 935 |
+
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
|
| 936 |
|
| 937 |
+
#pragma unroll
|
| 938 |
+
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 939 |
+
const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
|
| 940 |
+
const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
|
| 941 |
+
const int stride_jc = WARP_SIZE / stride_k;
|
| 942 |
|
| 943 |
+
if (k0_start == k0_stop) {
|
| 944 |
continue;
|
| 945 |
}
|
| 946 |
|
|
|
|
| 947 |
#pragma unroll
|
| 948 |
+
for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
|
| 949 |
+
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
| 950 |
|
| 951 |
+
if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
|
| 952 |
+
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 953 |
}
|
| 954 |
|
| 955 |
+
const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
|
| 956 |
+
|
| 957 |
+
const int j_dst = jc_dst / ncols2;
|
| 958 |
+
const int c_dst = jc_dst % ncols2;
|
| 959 |
+
|
| 960 |
+
if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
|
| 961 |
+
continue;
|
| 962 |
}
|
| 963 |
|
| 964 |
+
const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
|
| 965 |
+
#pragma unroll
|
| 966 |
+
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 967 |
+
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 968 |
+
|
| 969 |
+
float2 dstk_val = make_float2(0.0f, 0.0f);
|
| 970 |
+
#pragma unroll
|
| 971 |
+
for (int ip = 0; ip < np; ++ip) {
|
| 972 |
+
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0];
|
| 973 |
+
const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]);
|
| 974 |
+
dstk_val.x += dstk_val_add.x*KQ_crs;
|
| 975 |
+
dstk_val.y += dstk_val_add.y*KQ_crs;
|
| 976 |
+
}
|
| 977 |
+
|
| 978 |
+
if (!needs_fixup && !is_fixup) {
|
| 979 |
+
const float KQ_rowsum_j = meta_j[1];
|
| 980 |
+
dstk_val.x /= KQ_rowsum_j;
|
| 981 |
+
dstk_val.y /= KQ_rowsum_j;
|
| 982 |
+
}
|
| 983 |
+
|
| 984 |
+
if (is_fixup) {
|
| 985 |
+
dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val;
|
| 986 |
+
} else {
|
| 987 |
+
dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val;
|
| 988 |
+
}
|
| 989 |
}
|
| 990 |
}
|
| 991 |
}
|
| 992 |
}
|
| 993 |
+
if (np > 1) {
|
| 994 |
+
__syncthreads();
|
| 995 |
+
}
|
|
|
|
| 996 |
}
|
| 997 |
#else
|
| 998 |
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
| 999 |
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
| 1000 |
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
| 1001 |
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
|
| 1002 |
+
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
|
| 1003 |
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
|
| 1004 |
NO_DEVICE_CODE;
|
| 1005 |
#endif // NEW_MMA_AVAILABLE
|
| 1006 |
}
|
| 1007 |
|
| 1008 |
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
|
| 1009 |
+
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
| 1010 |
static __global__ void flash_attn_ext_f16(
|
| 1011 |
const char * __restrict__ Q,
|
| 1012 |
const char * __restrict__ K,
|
|
|
|
| 1046 |
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
| 1047 |
|
| 1048 |
// Skip unused kernel variants for faster compilation:
|
| 1049 |
+
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
| 1050 |
NO_DEVICE_CODE;
|
| 1051 |
return;
|
| 1052 |
}
|
| 1053 |
|
| 1054 |
+
typedef fattn_mma_f16_config<DKQ, DV> c;
|
| 1055 |
+
|
| 1056 |
+
static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config<DKQ, DV>::nbatch_fa == 0, "bad nbatch_fa");
|
| 1057 |
|
| 1058 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 1059 |
|
| 1060 |
const int stride_Q1 = nb01 / sizeof(float2);
|
| 1061 |
const int stride_Q2 = nb02 / sizeof(float2);
|
| 1062 |
+
const int stride_K = nb11 / sizeof(half2);
|
| 1063 |
+
const int stride_V = nb21 / sizeof(half2);
|
| 1064 |
const int stride_mask = nb31 / sizeof(half2);
|
| 1065 |
|
| 1066 |
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
| 1067 |
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
| 1068 |
|
| 1069 |
+
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
| 1070 |
|
| 1071 |
// kbc == k block continuous, current index in continuous ijk space.
|
| 1072 |
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
|
|
| 1085 |
|
| 1086 |
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
| 1087 |
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
| 1088 |
+
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
| 1089 |
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
| 1090 |
+
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
| 1091 |
|
| 1092 |
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
| 1093 |
|
|
|
|
| 1097 |
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
| 1098 |
if (kb0_start == 0) {
|
| 1099 |
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
| 1100 |
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
| 1101 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 1102 |
+
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 1103 |
} else {
|
| 1104 |
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
| 1105 |
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
| 1106 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 1107 |
+
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 1108 |
}
|
| 1109 |
|
| 1110 |
kbc += iter_k;
|
|
|
|
| 1123 |
|
| 1124 |
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
| 1125 |
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
| 1126 |
+
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
| 1127 |
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
| 1128 |
+
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
| 1129 |
|
| 1130 |
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
| 1131 |
|
|
|
|
| 1134 |
|
| 1135 |
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
| 1136 |
constexpr bool needs_fixup = false;
|
| 1137 |
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
| 1138 |
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 1139 |
+
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 1140 |
#else
|
| 1141 |
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
| 1142 |
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
|
|
|
| 1152 |
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
| 1153 |
}
|
| 1154 |
|
| 1155 |
+
template <int DKQ, int DV, int ncols1, int ncols2>
|
| 1156 |
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 1157 |
+
const ggml_tensor * KQV = dst;
|
| 1158 |
+
const int id = ggml_cuda_get_device();
|
| 1159 |
+
const int cc = ggml_cuda_info().devices[id].cc;
|
| 1160 |
+
|
| 1161 |
+
typedef fattn_mma_f16_config<DKQ, DV> c;
|
| 1162 |
+
|
| 1163 |
+
constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
|
| 1164 |
+
constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
|
| 1165 |
+
constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
|
| 1166 |
+
|
| 1167 |
+
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
|
| 1168 |
+
|
| 1169 |
constexpr int ncols = ncols1 * ncols2;
|
| 1170 |
+
constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
|
|
|
|
|
|
|
| 1171 |
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 1172 |
+
constexpr int nwarps_max_x = ncols / cols_per_warp;
|
| 1173 |
+
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
|
| 1174 |
+
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
|
| 1175 |
|
| 1176 |
+
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
|
| 1177 |
+
static_assert(DV % tile_A::J == 0, "bad DV");
|
| 1178 |
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
| 1179 |
|
| 1180 |
+
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
|
| 1181 |
+
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
|
| 1182 |
+
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
| 1183 |
+
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
| 1184 |
+
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
| 1185 |
|
| 1186 |
+
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
|
| 1187 |
|
| 1188 |
+
const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ?
|
| 1189 |
+
std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
|
| 1190 |
+
nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
|
|
|
|
|
|
|
| 1191 |
|
| 1192 |
float logit_softcap;
|
| 1193 |
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
|
|
|
| 1195 |
fattn_kernel_t fattn_kernel;
|
| 1196 |
if (logit_softcap == 0.0f) {
|
| 1197 |
constexpr bool use_logit_softcap = false;
|
| 1198 |
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
| 1199 |
+
|
| 1200 |
+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 1201 |
+
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
| 1202 |
+
if (!shared_memory_limit_raised[id]) {
|
| 1203 |
+
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
| 1204 |
+
shared_memory_limit_raised[id] = true;
|
| 1205 |
+
}
|
| 1206 |
+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 1207 |
} else {
|
| 1208 |
constexpr bool use_logit_softcap = true;
|
| 1209 |
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
| 1210 |
+
|
| 1211 |
+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 1212 |
+
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
| 1213 |
+
if (!shared_memory_limit_raised[id]) {
|
| 1214 |
+
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
| 1215 |
+
shared_memory_limit_raised[id] = true;
|
| 1216 |
+
}
|
| 1217 |
+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 1218 |
}
|
| 1219 |
|
| 1220 |
+
launch_fattn<DV, ncols1, ncols2>
|
| 1221 |
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
|
| 1222 |
}
|
| 1223 |
|
| 1224 |
|
| 1225 |
+
#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \
|
| 1226 |
+
template void ggml_cuda_flash_attn_ext_mma_f16_case \
|
| 1227 |
+
<DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
| 1228 |
+
|
| 1229 |
+
#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \
|
| 1230 |
+
extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \
|
| 1231 |
+
extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \
|
| 1232 |
+
extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \
|
| 1233 |
+
extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \
|
| 1234 |
+
extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \
|
| 1235 |
+
|
| 1236 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8)
|
| 1237 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8)
|
| 1238 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8)
|
| 1239 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8)
|
| 1240 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8)
|
| 1241 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8)
|
| 1242 |
+
|
| 1243 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16)
|
| 1244 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16)
|
| 1245 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16)
|
| 1246 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16)
|
| 1247 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16)
|
| 1248 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16)
|
| 1249 |
+
|
| 1250 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32)
|
| 1251 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32)
|
| 1252 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32)
|
| 1253 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32)
|
| 1254 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32)
|
| 1255 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32)
|
| 1256 |
+
|
| 1257 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64)
|
| 1258 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64)
|
| 1259 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64)
|
| 1260 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
|
| 1261 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
|
| 1262 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
| 1263 |
+
|
| 1264 |
+
// The number of viable configurations for Deepseek is very limited:
|
| 1265 |
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
| 1266 |
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
| 1267 |
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
|
|
|
|
|
ggml/src/ggml-cuda/fattn-tile-f16.cu
CHANGED
|
@@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
| 307 |
constexpr int nwarps = 8;
|
| 308 |
constexpr size_t nbytes_shared = 0;
|
| 309 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 310 |
-
launch_fattn<D, cols_per_block, 1
|
| 311 |
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
| 312 |
} break;
|
| 313 |
case 128: {
|
|
@@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
| 315 |
constexpr int nwarps = 8;
|
| 316 |
constexpr size_t nbytes_shared = 0;
|
| 317 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 318 |
-
launch_fattn<D, cols_per_block, 1
|
| 319 |
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
| 320 |
} break;
|
| 321 |
default: {
|
|
|
|
| 307 |
constexpr int nwarps = 8;
|
| 308 |
constexpr size_t nbytes_shared = 0;
|
| 309 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 310 |
+
launch_fattn<D, cols_per_block, 1>
|
| 311 |
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
| 312 |
} break;
|
| 313 |
case 128: {
|
|
|
|
| 315 |
constexpr int nwarps = 8;
|
| 316 |
constexpr size_t nbytes_shared = 0;
|
| 317 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 318 |
+
launch_fattn<D, cols_per_block, 1>
|
| 319 |
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
| 320 |
} break;
|
| 321 |
default: {
|
ggml/src/ggml-cuda/fattn-tile-f32.cu
CHANGED
|
@@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
| 318 |
constexpr int nwarps = 8;
|
| 319 |
constexpr size_t nbytes_shared = 0;
|
| 320 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 321 |
-
launch_fattn<D, cols_per_block, 1
|
| 322 |
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
| 323 |
} break;
|
| 324 |
case 128: {
|
|
@@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
| 326 |
constexpr int nwarps = 8;
|
| 327 |
constexpr size_t nbytes_shared = 0;
|
| 328 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 329 |
-
launch_fattn<D, cols_per_block, 1
|
| 330 |
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
| 331 |
} break;
|
| 332 |
default: {
|
|
|
|
| 318 |
constexpr int nwarps = 8;
|
| 319 |
constexpr size_t nbytes_shared = 0;
|
| 320 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 321 |
+
launch_fattn<D, cols_per_block, 1>
|
| 322 |
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
| 323 |
} break;
|
| 324 |
case 128: {
|
|
|
|
| 326 |
constexpr int nwarps = 8;
|
| 327 |
constexpr size_t nbytes_shared = 0;
|
| 328 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 329 |
+
launch_fattn<D, cols_per_block, 1>
|
| 330 |
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
| 331 |
} break;
|
| 332 |
default: {
|
ggml/src/ggml-cuda/fattn-vec-f16.cuh
CHANGED
|
@@ -315,7 +315,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
|
|
| 315 |
constexpr bool need_f16_K = D != 128;
|
| 316 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 317 |
constexpr size_t nbytes_shared = 0;
|
| 318 |
-
launch_fattn<D, cols_per_block, 1
|
| 319 |
}
|
| 320 |
|
| 321 |
template <int D, ggml_type type_K, ggml_type type_V>
|
|
|
|
| 315 |
constexpr bool need_f16_K = D != 128;
|
| 316 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 317 |
constexpr size_t nbytes_shared = 0;
|
| 318 |
+
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
| 319 |
}
|
| 320 |
|
| 321 |
template <int D, ggml_type type_K, ggml_type type_V>
|
ggml/src/ggml-cuda/fattn-vec-f32.cuh
CHANGED
|
@@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
|
| 310 |
constexpr bool need_f16_K = D != 128;
|
| 311 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 312 |
constexpr size_t nbytes_shared = 0;
|
| 313 |
-
launch_fattn<D, cols_per_block, 1
|
| 314 |
}
|
| 315 |
|
| 316 |
template <int D, ggml_type type_K, ggml_type type_V>
|
|
|
|
| 310 |
constexpr bool need_f16_K = D != 128;
|
| 311 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 312 |
constexpr size_t nbytes_shared = 0;
|
| 313 |
+
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
| 314 |
}
|
| 315 |
|
| 316 |
template <int D, ggml_type type_K, ggml_type type_V>
|
ggml/src/ggml-cuda/fattn-wmma-f16.cu
CHANGED
|
@@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
|
| 490 |
fattn_kernel = flash_attn_ext_f16<
|
| 491 |
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
|
| 492 |
}
|
| 493 |
-
launch_fattn<D, cols_per_block, 1
|
| 494 |
}
|
| 495 |
|
| 496 |
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
| 490 |
fattn_kernel = flash_attn_ext_f16<
|
| 491 |
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
|
| 492 |
}
|
| 493 |
+
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
|
| 494 |
}
|
| 495 |
|
| 496 |
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
ggml/src/ggml-cuda/fattn.cu
CHANGED
|
@@ -8,58 +8,32 @@
|
|
| 8 |
#include "fattn-wmma-f16.cuh"
|
| 9 |
#include "fattn.cuh"
|
| 10 |
|
| 11 |
-
template <int
|
| 12 |
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 13 |
const ggml_tensor * Q = dst->src[0];
|
| 14 |
|
| 15 |
-
if (
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
}
|
| 19 |
|
| 20 |
if (Q->ne[1] <= 16/ncols2) {
|
| 21 |
-
ggml_cuda_flash_attn_ext_mma_f16_case<
|
| 22 |
return;
|
| 23 |
}
|
| 24 |
|
| 25 |
if (Q->ne[1] <= 32/ncols2) {
|
| 26 |
-
ggml_cuda_flash_attn_ext_mma_f16_case<
|
| 27 |
return;
|
| 28 |
}
|
| 29 |
|
| 30 |
-
ggml_cuda_flash_attn_ext_mma_f16_case<
|
| 31 |
-
}
|
| 32 |
-
|
| 33 |
-
template <int ncols2>
|
| 34 |
-
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 35 |
-
const ggml_tensor * Q = dst->src[0];
|
| 36 |
-
|
| 37 |
-
switch (Q->ne[0]) {
|
| 38 |
-
case 64:
|
| 39 |
-
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
|
| 40 |
-
break;
|
| 41 |
-
case 80:
|
| 42 |
-
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
|
| 43 |
-
break;
|
| 44 |
-
case 96:
|
| 45 |
-
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
|
| 46 |
-
break;
|
| 47 |
-
case 112:
|
| 48 |
-
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
|
| 49 |
-
break;
|
| 50 |
-
case 128:
|
| 51 |
-
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
| 52 |
-
break;
|
| 53 |
-
case 256:
|
| 54 |
-
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
| 55 |
-
break;
|
| 56 |
-
default:
|
| 57 |
-
GGML_ABORT("fatal error");
|
| 58 |
-
break;
|
| 59 |
-
}
|
| 60 |
}
|
| 61 |
|
| 62 |
-
|
|
|
|
| 63 |
const ggml_tensor * KQV = dst;
|
| 64 |
const ggml_tensor * Q = dst->src[0];
|
| 65 |
const ggml_tensor * K = dst->src[1];
|
|
@@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|
| 68 |
float max_bias = 0.0f;
|
| 69 |
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
| 70 |
|
| 71 |
-
const
|
| 72 |
|
| 73 |
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
| 74 |
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
| 75 |
|
| 76 |
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
| 77 |
-
|
| 78 |
return;
|
| 79 |
}
|
| 80 |
|
| 81 |
-
if (use_gqa_opt && gqa_ratio ==
|
| 82 |
-
|
| 83 |
return;
|
| 84 |
}
|
| 85 |
|
| 86 |
-
if (use_gqa_opt && gqa_ratio ==
|
| 87 |
-
|
| 88 |
return;
|
| 89 |
}
|
| 90 |
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
}
|
| 93 |
|
| 94 |
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
|
@@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 299 |
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
| 300 |
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
| 301 |
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
|
| 302 |
-
const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
|
| 303 |
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
| 304 |
if (prec == GGML_PREC_DEFAULT) {
|
| 305 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
|
|
|
| 8 |
#include "fattn-wmma-f16.cuh"
|
| 9 |
#include "fattn.cuh"
|
| 10 |
|
| 11 |
+
template <int DKQ, int DV, int ncols2>
|
| 12 |
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 13 |
const ggml_tensor * Q = dst->src[0];
|
| 14 |
|
| 15 |
+
if constexpr (ncols2 <= 8) {
|
| 16 |
+
if (Q->ne[1] <= 8/ncols2) {
|
| 17 |
+
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
|
| 18 |
+
return;
|
| 19 |
+
}
|
| 20 |
}
|
| 21 |
|
| 22 |
if (Q->ne[1] <= 16/ncols2) {
|
| 23 |
+
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
|
| 24 |
return;
|
| 25 |
}
|
| 26 |
|
| 27 |
if (Q->ne[1] <= 32/ncols2) {
|
| 28 |
+
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
| 29 |
return;
|
| 30 |
}
|
| 31 |
|
| 32 |
+
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
}
|
| 34 |
|
| 35 |
+
template <int DKQ, int DV>
|
| 36 |
+
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 37 |
const ggml_tensor * KQV = dst;
|
| 38 |
const ggml_tensor * Q = dst->src[0];
|
| 39 |
const ggml_tensor * K = dst->src[1];
|
|
|
|
| 42 |
float max_bias = 0.0f;
|
| 43 |
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
| 44 |
|
| 45 |
+
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
| 46 |
|
| 47 |
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
| 48 |
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
| 49 |
|
| 50 |
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
| 51 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
|
| 52 |
return;
|
| 53 |
}
|
| 54 |
|
| 55 |
+
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
| 56 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
|
| 57 |
return;
|
| 58 |
}
|
| 59 |
|
| 60 |
+
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
| 61 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
| 62 |
return;
|
| 63 |
}
|
| 64 |
|
| 65 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 69 |
+
const ggml_tensor * KQV = dst;
|
| 70 |
+
const ggml_tensor * Q = dst->src[0];
|
| 71 |
+
const ggml_tensor * K = dst->src[1];
|
| 72 |
+
const ggml_tensor * V = dst->src[2];
|
| 73 |
+
const ggml_tensor * mask = dst->src[3];
|
| 74 |
+
|
| 75 |
+
switch (Q->ne[0]) {
|
| 76 |
+
case 64:
|
| 77 |
+
GGML_ASSERT(V->ne[0] == 64);
|
| 78 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst);
|
| 79 |
+
break;
|
| 80 |
+
case 80:
|
| 81 |
+
GGML_ASSERT(V->ne[0] == 80);
|
| 82 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst);
|
| 83 |
+
break;
|
| 84 |
+
case 96:
|
| 85 |
+
GGML_ASSERT(V->ne[0] == 96);
|
| 86 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst);
|
| 87 |
+
break;
|
| 88 |
+
case 112:
|
| 89 |
+
GGML_ASSERT(V->ne[0] == 112);
|
| 90 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
|
| 91 |
+
break;
|
| 92 |
+
case 128:
|
| 93 |
+
GGML_ASSERT(V->ne[0] == 128);
|
| 94 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
|
| 95 |
+
break;
|
| 96 |
+
case 256:
|
| 97 |
+
GGML_ASSERT(V->ne[0] == 256);
|
| 98 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
| 99 |
+
break;
|
| 100 |
+
case 576: {
|
| 101 |
+
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
| 102 |
+
GGML_ASSERT(V->ne[0] == 512);
|
| 103 |
+
float max_bias = 0.0f;
|
| 104 |
+
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
| 105 |
+
|
| 106 |
+
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
| 107 |
+
GGML_ASSERT(use_gqa_opt);
|
| 108 |
+
|
| 109 |
+
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
| 110 |
+
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
| 111 |
+
GGML_ASSERT(gqa_ratio % 16 == 0);
|
| 112 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
| 113 |
+
} break;
|
| 114 |
+
default:
|
| 115 |
+
GGML_ABORT("fatal error");
|
| 116 |
+
break;
|
| 117 |
+
}
|
| 118 |
}
|
| 119 |
|
| 120 |
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
|
|
|
| 325 |
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
| 326 |
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
| 327 |
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
|
| 328 |
+
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
|
| 329 |
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
| 330 |
if (prec == GGML_PREC_DEFAULT) {
|
| 331 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -3215,16 +3215,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 3215 |
return false;
|
| 3216 |
#endif // FLASH_ATTN_AVAILABLE
|
| 3217 |
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
| 3218 |
-
|
| 3219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3220 |
}
|
| 3221 |
if (op->src[0]->ne[0] == 192) {
|
| 3222 |
return false;
|
| 3223 |
}
|
| 3224 |
-
if (op->src[0]->ne[0] == 576) {
|
| 3225 |
-
// DeepSeek MLA
|
| 3226 |
-
return false;
|
| 3227 |
-
}
|
| 3228 |
if (op->src[0]->ne[3] != 1) {
|
| 3229 |
return false;
|
| 3230 |
}
|
|
|
|
| 3215 |
return false;
|
| 3216 |
#endif // FLASH_ATTN_AVAILABLE
|
| 3217 |
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
| 3218 |
+
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
| 3219 |
+
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
|
| 3220 |
+
return false;
|
| 3221 |
+
}
|
| 3222 |
+
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
| 3223 |
+
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
|
| 3224 |
}
|
| 3225 |
if (op->src[0]->ne[0] == 192) {
|
| 3226 |
return false;
|
| 3227 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3228 |
if (op->src[0]->ne[3] != 1) {
|
| 3229 |
return false;
|
| 3230 |
}
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 1, 8);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 1, 8);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 1, 8);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 1, 8);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 16, 1);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 16, 1);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 16, 1);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 16, 1);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 16, 1);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 16, 1);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 16, 2);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 16, 2);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 16, 2);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 16, 2);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 16, 2);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 16, 2);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 16, 4);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 16, 4);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 16, 4);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 16, 4);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 16, 4);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 16, 4);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 2, 4);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 2, 4);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 2, 4);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 2, 4);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 2, 4);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 2, 4);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 2, 8);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 2, 8);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 2, 8);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 2, 8);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 32, 1);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 32, 1);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 32, 1);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 32, 1);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 32, 1);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 32, 1);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 32, 2);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 32, 2);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 32, 2);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 32, 2);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 32, 2);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 32, 2);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 4, 2);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 4, 2);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 4, 2);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 4, 2);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 4, 2);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 4, 2);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 4, 4);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 4, 4);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 4, 4);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 4, 4);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 4, 4);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 4, 4);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 4, 8);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 4, 8);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 4, 8);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 4, 8);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 64, 1);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 64, 1);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 64, 1);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 64, 1);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 64, 1);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 64, 1);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 8, 1);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 8, 1);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 8, 1);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 8, 1);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 8, 1);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 8, 1);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 8, 2);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 8, 2);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 8, 2);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 8, 2);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 8, 2);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 8, 2);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 8, 4);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 8, 4);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 8, 4);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 8, 4);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 8, 4);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 8, 4);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
CHANGED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64, 8, 8);
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80, 8, 8);
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96, 8, 8);
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112, 8, 8);
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128, 8, 8);
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256, 8, 8);
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
|
ggml/src/ggml-cuda/template-instances/generate_cu_files.py
CHANGED
|
@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
|
|
| 18 |
|
| 19 |
"""
|
| 20 |
|
| 21 |
-
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({
|
| 22 |
|
| 23 |
TYPES_MMQ = [
|
| 24 |
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
|
@@ -57,18 +57,21 @@ for vkq_size in [16, 32]:
|
|
| 57 |
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
| 58 |
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
|
| 59 |
|
| 60 |
-
for ncols in [8, 16, 32, 64
|
| 61 |
-
for ncols2 in [1, 2, 4, 8]:
|
|
|
|
|
|
|
| 62 |
ncols1 = ncols // ncols2
|
| 63 |
-
if ncols == 128:
|
| 64 |
-
continue # Too much register pressure.
|
| 65 |
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
|
| 66 |
f.write(SOURCE_FATTN_MMA_START)
|
| 67 |
|
| 68 |
-
for
|
| 69 |
-
if
|
| 70 |
-
continue
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
for type in TYPES_MMQ:
|
| 74 |
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|
|
|
|
| 18 |
|
| 19 |
"""
|
| 20 |
|
| 21 |
+
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
|
| 22 |
|
| 23 |
TYPES_MMQ = [
|
| 24 |
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
|
|
|
| 57 |
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
| 58 |
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
|
| 59 |
|
| 60 |
+
for ncols in [8, 16, 32, 64]:
|
| 61 |
+
for ncols2 in [1, 2, 4, 8, 16]:
|
| 62 |
+
if ncols2 > ncols:
|
| 63 |
+
continue
|
| 64 |
ncols1 = ncols // ncols2
|
|
|
|
|
|
|
| 65 |
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
|
| 66 |
f.write(SOURCE_FATTN_MMA_START)
|
| 67 |
|
| 68 |
+
for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
|
| 69 |
+
if head_size_kq != 576 and ncols2 == 16:
|
| 70 |
+
continue
|
| 71 |
+
if head_size_kq == 576 and ncols2 != 16:
|
| 72 |
+
continue
|
| 73 |
+
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
| 74 |
+
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
|
| 75 |
|
| 76 |
for type in TYPES_MMQ:
|
| 77 |
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|