Commit 5b9980d
Parent(s): 3c2171d
CUDA: use async data loading for FlashAttention (llama/11894)
---------
Co-authored-by: Diego Devesa <[email protected]>
- ggml/src/ggml-cuda/common.cuh +15 -6
- ggml/src/ggml-cuda/cp-async.cuh +46 -0
- ggml/src/ggml-cuda/fattn-common.cuh +9 -6
- ggml/src/ggml-cuda/fattn-mma-f16.cuh +342 -258
- ggml/src/ggml-cuda/mma.cuh +172 -311
- ggml/src/ggml-cuda/mmq.cuh +140 -138
ggml/src/ggml-cuda/common.cuh
CHANGED
@@ -41,12 +41,13 @@

The block of GGML_CUDA_CC_* compute capability defines is rewritten and gains a GGML_CUDA_CC_ADA_LOVELACE entry; after the change it reads:

#define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

#define GGML_CUDA_CC_PASCAL       600
#define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA        700
#define GGML_CUDA_CC_TURING       750
#define GGML_CUDA_CC_AMPERE       800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_OFFSET_AMD   0x1000000

// GCN/CNDA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16

@@ -199,6 +200,10 @@ typedef float2 dfloat2;

A CP_ASYNC_AVAILABLE compile-time flag is added next to NEW_MMA_AVAILABLE; it is set on NVIDIA GPUs with __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE:

#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)

@@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {

A matching host-side helper, cp_async_available, is added after new_mma_available:

    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}

static bool cp_async_available(const int cc) {
    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}

static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
    return __AMDGCN_WAVEFRONT_SIZE;
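The macro and the helper are meant to be used as a pair: CP_ASYNC_AVAILABLE gates which device code paths get compiled for a given architecture, while cp_async_available(cc) lets host code adapt launch parameters at runtime. Below is a minimal sketch of that pairing, assuming common.cuh is included; the kernel, launcher and tile-size names are hypothetical and only for illustration, they are not part of the commit:

constexpr size_t NBYTES_TILE = 64*64*sizeof(half2); // made-up per-tile size, illustration only

static __global__ void example_kernel(const half2 * src) {
#ifdef CP_ASYNC_AVAILABLE
    // Compiled only for NVIDIA GPUs with __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE:
    // this is where cp.async based loading can be used.
#else
    // Older architectures and AMD: synchronous fallback path.
#endif // CP_ASYNC_AVAILABLE
    GGML_UNUSED(src);
}

static void example_launch(const half2 * src, const int cc, cudaStream_t stream) {
    // Host code cannot test __CUDA_ARCH__, so it branches on the runtime compute capability
    // instead, e.g. to reserve shared memory for two tiles when double buffering is possible:
    const size_t nbytes_shared = cp_async_available(cc) ? 2*NBYTES_TILE : NBYTES_TILE;
    example_kernel<<<1, WARP_SIZE, nbytes_shared, stream>>>(src);
}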
ggml/src/ggml-cuda/cp-async.cuh
ADDED
@@ -0,0 +1,46 @@

New file, all 46 lines added:

// Simplified API for asynchronous data loading.

#include "common.cuh"

// Copies data from global to shared memory, cg == cache global.
// Both the src and dst pointers must be aligned to 16 bit.
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.
template <int preload>
static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
    static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
#ifdef CP_ASYNC_AVAILABLE
#if CUDART_VERSION >= 11040
    if (preload == 256) {
        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
            : : "r"(dst), "l"(src));
    } else if (preload == 128) {
        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
            : : "r"(dst), "l"(src));
    } else if (preload == 64) {
        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
            : : "r"(dst), "l"(src));
    } else
#endif // CUDART_VERSION >= 11040
    {
        asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
            : : "r"(dst), "l"(src));
    }
#else
    GGML_UNUSED(dst);
    GGML_UNUSED(src);
    NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}

// Makes each thread wait until its asynchronous data copies are done.
// This does NOT provide any additional synchronization.
// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
static __device__ __forceinline__ void cp_async_wait_all() {
#ifdef CP_ASYNC_AVAILABLE
    asm volatile("cp.async.wait_all;");
#else
    NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}
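A short usage sketch for these helpers (illustrative only, assuming common.cuh and cp-async.cuh are included, a full warp calls the function, and src/dst_shared are 16-byte aligned): each copy moves 16 bytes to a 32 bit shared-memory address obtained with __cvta_generic_to_shared, and because cp_async_wait_all only waits for the calling thread's own copies, a __syncwarp (or __syncthreads across warps) is still needed before other threads read the data. The function name is hypothetical.

// Cooperatively copy n_half2 values into shared memory with one warp, 16 bytes per thread per step.
static __device__ void example_async_fill(const half2 * __restrict__ src, half2 * __restrict__ dst_shared, const int n_half2) {
    constexpr int h2_per_chunk = 16/sizeof(half2); // 4 half2 values per 16 byte copy
    const unsigned int dst32 = __cvta_generic_to_shared(dst_shared);

    for (int k0 = threadIdx.x*h2_per_chunk; k0 < n_half2; k0 += WARP_SIZE*h2_per_chunk) {
        cp_async_cg_16<0>(dst32 + k0*sizeof(half2), src + k0);
    }

    cp_async_wait_all(); // waits for this thread's copies only
    __syncwarp();        // make the data visible to the rest of the warp
}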
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED
@@ -716,7 +716,9 @@ void launch_fattn(

The single device-info lookup is expanded so that the compute capability is available alongside the SM count:

    ggml_cuda_pool & pool = ctx.pool();
    cudaStream_t main_stream = ctx.stream();
    const int id  = ggml_cuda_get_device();
    const int cc  = ggml_cuda_info().devices[id].cc;
    const int nsm = ggml_cuda_info().devices[id].nsm;

    ggml_cuda_pool_alloc<half> K_f16(pool);
    ggml_cuda_pool_alloc<half> V_f16(pool);

@@ -768,13 +770,14 @@

The short_context heuristic (K->ne[1] < 4096) for choosing between whole-tile and stream-k decomposition is replaced by a tile-efficiency estimate: stream-k is used when a whole-tile launch would keep less than 75% of the 2*nsm block slots per wave busy, or whenever the compute capability is at least Ada Lovelace:

    dim3 blocks_num;
    if (parallel_blocks == 0) {
        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
        const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);

        const int nblocks_stream_k = 2*nsm;

        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;

        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
        blocks_num.y = 1;
        blocks_num.z = 1;

@@ -827,7 +830,7 @@

The fixup pass is now launched only when the tiles do not divide evenly across the launched blocks:

    CUDA_CHECK(cudaGetLastError());

    if constexpr (parallel_blocks == 0) {
        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
            const dim3 block_dim_combine(D, 1, 1);
            const dim3 blocks_num_combine = blocks_num;
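To make the heuristic concrete, here is a small worked example with illustrative numbers (not taken from the commit), assuming launch_fattn behaves exactly as shown above:

// Same decision as above, for a hypothetical GPU with nsm = 108 SMs:
static bool example_use_stream_k(const int ntiles_total, const int nsm, const bool is_ada_or_newer) {
    const int tiles_nwaves             = (ntiles_total + 2*nsm - 1) / (2*nsm);
    const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
    return tiles_efficiency_percent < 75 || is_ada_or_newer;
}
// ntiles_total = 1000: 5 waves of 216 blocks, efficiency 100*1000/1080 = 92 % -> whole-tile decomposition (pre-Ada).
// ntiles_total =  450: 3 waves of 216 blocks, efficiency 100*450/648   = 69 % -> stream-k decomposition.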
ggml/src/ggml-cuda/fattn-mma-f16.cuh
CHANGED
The MMA FlashAttention kernel is rewritten on top of the generic ggml_cuda_mma::tile types and the new cp-async helpers. The K/V tile load and the per-tile iteration are split out of flash_attn_ext_f16_process_tile into two new device functions, and when cp.async is available the shared buffer is doubled so that the next K or V tile is fetched asynchronously while the current one is being multiplied. The SOFTMAX_FTZ_THRESHOLD flush-to-zero checks in the softmax update are removed along the way.

New includes and tile typedefs at the top of the file:

#include "common.cuh"
#include "cp-async.cuh"
#include "mma.cuh"
#include "fattn-common.cuh"

using namespace ggml_cuda_mma;

typedef tile<16,  8, half2> tile_A;
typedef tile< 8,  8, half2> tile_B;
typedef tile<16,  8, float> tile_C_KQ;
typedef tile<16,  4, half2> tile_C_VKQ;

The new tile loader copies a KQ_stride x D slice of K or V into shared memory. With cp.async the largest power-of-two prefix of D/2 is loaded asynchronously in 16 byte chunks and only the remainder is loaded synchronously:

template<int D, int nwarps, int KQ_stride>
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.

    // If cp.async is available, load up to the highest power of 2 in D asynchronously:
#ifdef CP_ASYNC_AVAILABLE
    static_assert(D >= 64 && D < 512, "bad D");
    constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);

    const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);

    constexpr int preload = 64;
    constexpr int h2_per_chunk = 16/sizeof(half2);
    constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
    constexpr int stride_i = WARP_SIZE / chunks_per_row;
#pragma unroll
    for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
        const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
        const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;

        cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
    }
#else
    constexpr int k0_sync_start = 0;
#endif // CP_ASYNC_AVAILABLE
    static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");

    // If D is not a power of 2, the rest is loaded synchronously.
    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
    static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
        const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
        const int k0_stop  =                                         D/2 - (D/2) % (1*stride_k);
        const int stride_i = WARP_SIZE / stride_k;

        if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
            continue;
        }

#pragma unroll
        for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
            const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

#pragma unroll
            for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
            }
        }
    }
}

flash_attn_ext_f16_iter processes one KQ_stride tile of previous tokens. With cp.async it waits for the K tile that is already in flight, immediately starts the matching V load, computes KQ, softmax and the VKQ update, and (unless this is the last iteration) kicks off the K load for the next tile; without cp.async, tile_K and tile_V alias the same buffer and the loads stay synchronous:

template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
        const float2 * const __restrict__ Q_f2,
        const half2  * const __restrict__ K_h2,
        const half2  * const __restrict__ V_h2,
        const half   * const __restrict__ maskh,
        float2       * const __restrict__ dstk,
        float2       * const __restrict__ dstk_fixup,
        const float scale,
        const float slope,
        const float logit_softcap,
        const int ne01,
        const int ne02,
        const int stride_Q,
        const int stride_KV,
        const int stride_mask,
        const int jt,
        half2        * const __restrict__ tile_K,
        half2        * const __restrict__ tile_V,
        const tile_B * const __restrict__ Q_B,
        tile_C_VKQ   * const __restrict__ VKQ_C,
        float2       & KQ_max,
        float2       & KQ_rowsum,
        const int kb0) {
#ifdef NEW_MMA_AVAILABLE
    constexpr int np        = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.

    const int k_VKQ_0 = kb0*KQ_stride;
    tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)];

#ifdef CP_ASYNC_AVAILABLE
    cp_async_wait_all();
    __syncthreads();
    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
#else
    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
    __syncthreads();
#endif // CP_ASYNC_AVAILABLE

    // Calculate tile of KQ:
#pragma unroll
    for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) {
        const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
#pragma unroll
        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
            tile_A K_A;
            load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
            mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]);
        }
    }

#ifndef CP_ASYNC_AVAILABLE
    __syncthreads(); // Only needed if tile_K == tile_V.
#endif // CP_ASYNC_AVAILABLE

    if (use_logit_softcap) {
        static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
        for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) {
#pragma unroll
            for (int l = 0; l < tile_C_KQ::ne; ++l) {
                KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
            }
        }
    }

    if (maskh) {
        static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size");
        static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size");
#pragma unroll
        for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) {
            const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
#pragma unroll
            for (int l = 0; l < tile_C_KQ::ne; ++l) {
                const int i = i0 + tile_C_KQ::get_i(l);
                const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l);

                KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
            }
        }
    }

    // Calculate softmax for each KQ column using the current max. value.
    // The divisor is stored in KQ_rowsum and will be applied at the end.
    float2 KQ_max_new = KQ_max;
    static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
    for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
#pragma unroll
        for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) {
            KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
            KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
        }
    }

    // Values per KQ column are spread across 8 threads, does not need full warp reduce:
#pragma unroll
    for (int offset = 16; offset > 2; offset >>= 1) {
        KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
        KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
    }

    float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
    static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
    for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
#pragma unroll
        for (int l = 0; l < tile_C_KQ::ne; ++l) {
            const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y;
            const float diff = KQ_C[k].x[l] - KQ_max_l;
            KQ_C[k].x[l] = expf(diff);

            if (l % 2 == 0) {
                KQ_rowsum_add.x += KQ_C[k].x[l];
            } else {
                KQ_rowsum_add.y += KQ_C[k].x[l];
            }
        }
    }

    {
        const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
        const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
        KQ_max = KQ_max_new;

        // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
        KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
        KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;

        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
#pragma unroll
        for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
#pragma unroll
            for (int l = 0; l < tile_C_VKQ::ne; ++l) {
                VKQ_C[i].x[l] *= KQ_max_scale_h2;
            }
        }
    }

    // Convert KQ C tiles into B tiles for VKQ calculation:
    tile_B B[KQ_stride/(np*2*tile_B::J)];
    static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size");
#pragma unroll
    for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) {
        B[k] = get_transposed(get_half2(KQ_C[k]));
    }

#ifdef CP_ASYNC_AVAILABLE
    cp_async_wait_all();
    __syncthreads();
    if (!last_iter) {
        flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV);
    }
#else
    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
    __syncthreads();
#endif // CP_ASYNC_AVAILABLE

    // Calculate VKQ tile:
#pragma unroll
    for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
        static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size");
#pragma unroll
        for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) {
            const int k0 = k00 + (threadIdx.y % np)*tile_A::J;

            tile_A A;
            load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
            mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
        }
    }

#ifndef CP_ASYNC_AVAILABLE
    __syncthreads(); // Only needed if tile_K == tile_V.
#endif // CP_ASYNC_AVAILABLE

#else
    NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}

flash_attn_ext_f16_process_tile keeps its overall structure but now receives the Q/KV/mask strides explicitly, holds Q in registers as tile_B fragments, splits the shared buffer into tile_K and tile_V when cp.async is available, preloads the first K tile and peels the last iteration so the prefetch never runs past kb0_stop:

template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
        const float2 * const __restrict__ Q_f2,
        const half2  * const __restrict__ K_h2,
        const half2  * const __restrict__ V_h2,
        const half   * const __restrict__ maskh,
        float2       * const __restrict__ dstk,
        float2       * const __restrict__ dstk_fixup,
        const float scale,
        const float slope,
        const float logit_softcap,
        const int ne01,
        const int ne02,
        const int stride_Q,
        const int stride_KV,
        const int stride_mask,
        const int jt,
        const int kb0_start,
        const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
    constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.

    static_assert(D         % nwarps == 0, "bad D");
    static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");

    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.

    // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements:
    extern __shared__ half2 tile_K[];
#ifdef CP_ASYNC_AVAILABLE
    half2 * tile_V = tile_K + KQ_stride*D2_padded;
#else
    half2 * tile_V = tile_K;
#endif // CP_ASYNC_AVAILABLE

    tile_B       Q_B[D/(2*tile_B::J)];
    tile_C_VKQ VKQ_C[D/tile_C_VKQ::I];

    float2 KQ_rowsum = {0.0f, 0.0f};
    float2 KQ_max    = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};

    // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
    // The loading is done with decreasing granularity for D for better memory bandwidth.
    const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
        const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
        const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
        const int stride_j = WARP_SIZE / stride_k;

        if (k0_start == k0_stop) {
            continue;
        }

        if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
            break;
        }

#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
            const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

            if (jt*ncols + j < ne01) {
#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
                    tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
                }
            } else {
#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f);
                }
            }
        }
    }

    __syncthreads();

    {
        const int j0 = (threadIdx.y / np) * tile_B::I;

#pragma unroll
        for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
            load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
        }
    }

    __syncthreads();

    // Preload K data for first iteration when using cp_async:
#ifdef CP_ASYNC_AVAILABLE
    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV);
#endif // CP_ASYNC_AVAILABLE

    // Iterate over ne11 == previous tokens:
    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
        constexpr bool last_iter = false;
        flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
            (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
             ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
    }
    { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
        constexpr bool last_iter = true;
        flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
            (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
             ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
    }

    // With cp_async there is no __syncthreads at the end of the iter,
    // there can be a race condition on shared memory access for combining/writing back results.
#ifdef CP_ASYNC_AVAILABLE
    if (nwarps*tile_B::I > KQ_stride) {
        __syncthreads();
    }
#endif // CP_ASYNC_AVAILABLE

    // Finally, sum up partial KQ rowsums.
    // The partial sums are spread across 8 threads each, does not need full reduce.

The rowsum reduction itself is unchanged. In the combine and write-back code the indexing moves from the old mma_* tile shapes to tile_B and tile_C_VKQ; the VKQ accumulators and the per-column meta data are staged in tile_K:

    // Write VKQ accumulators to shared memory in column-major format.
    // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
    // Also for np > 1 the combination is done via these values in shared memory.
    const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data
#pragma unroll
    for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
        const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.

#pragma unroll
        for (int l = 0; l < tile_B::ne; ++l) {
            const int k = k0 + tile_B::get_j(l);

            tile_K[j_cwd*D2_padded + k] = B.x[l];
        }
    }

    const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset
    const int j_cwm  = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
    const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum

    if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
        // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
        ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
    }

    __syncthreads();

For np > 1 the per-warp maxima and rowsums are still combined with warp shuffles (now over np*tile_B::I lanes), but the combined KQ max scale and rowsum are written back as a single float2:

        float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2;

        // Write back combined meta data:
        if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
            *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum.
        }
        if (needs_fixup && threadIdx.x < tile_B::I) {
            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
            dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
        }
        if (is_fixup && threadIdx.x < tile_B::I) {
            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
            dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
        }

The final write back to VRAM gains a k0_start == k0_stop guard and addresses the staged data through tile_K:

        const int k0_stop  = D/2 - (D/2) % (1*stride_k);
        const int stride_j = WARP_SIZE / stride_k;

        if (k0_start == k0_stop) {
            continue;
        }

        if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
            break;
        }

#pragma unroll
        for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
            const int j_dst    = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
            const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I;

            if (!is_fixup && jt*ncols + j_dst >= ne01) {
                continue;
            }
            const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2;
#pragma unroll
            for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                float2 dstk_val = make_float2(0.0f, 0.0f);
#pragma unroll
                for (int ip = 0; ip < np; ++ip) {
                    const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0];
                    const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]);
                    dstk_val.x += dstk_val_add.x*KQ_crs;
                    dstk_val.y += dstk_val_add.y*KQ_crs;
                }

and the NEW_MMA_AVAILABLE fallback of the function now emits NO_DEVICE_CODE:

        __syncthreads();
    }
#else
    NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}

The __global__ kernel flash_attn_ext_f16 bails out early when the new MMA instructions are unavailable:

        const int ne2,
        const int ne3) {
#ifndef NEW_MMA_AVAILABLE
    NO_DEVICE_CODE;
    return;
#endif // NEW_MMA_AVAILABLE

and computes the Q/KV/mask strides once:

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.

    const int stride_Q    = nb01 / sizeof(float2);
    const int stride_KV   = nb11 / sizeof(half2);
    const int stride_mask = nb31 / sizeof(half);

    const int iter_k = ne11 / KQ_stride;
    const int iter_j = (ne01 + (ncols - 1)) / ncols;

so that all three calls into the tile function now pass the strides instead of the full list of ne/nb arguments:

    flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
        (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
         ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);

Finally, the host-side launcher derives nwarps from the tile shapes and reserves shared memory for two K/V tiles when cp.async is available:

template <int D, int cols_per_block>
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    typedef tile<16, 8, half2> tile_A;
    typedef tile< 8, 8, half2> tile_B;

    static_assert(D              % tile_B::J == 0, "bad D");
    static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block");

    const ggml_tensor * KQV = dst;
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

    constexpr int KQ_stride = D <= 128 ? 64 : 32;
    constexpr int nwarps    = (KQ_stride == 32 && cols_per_block <= 16) ?
        cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8);

    const int    nrows_KQ      = cp_async_available(cc) ? 2*KQ_stride : KQ_stride;
    const int    nrows_combine = nwarps*tile_B::J;
    const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half);

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
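To isolate the scheduling idea from the FlashAttention machinery above, here is a minimal sketch of the same overlap pattern (illustrative only and not part of the commit; it assumes common.cuh and cp-async.cuh are included, a single warp per block, and a 16-byte aligned source): the next chunk is fetched with cp.async while the current chunk is consumed from shared memory.

// Double-buffered sum over n_chunks consecutive chunks of 4*WARP_SIZE floats each.
static __device__ float example_double_buffered_sum(const float * __restrict__ src, const int n_chunks) {
    constexpr int chunk = 4*WARP_SIZE; // each thread copies 16 bytes == 4 floats per chunk
    __shared__ float buf[2][chunk];

    const unsigned int buf32[2] = {
        (unsigned int) __cvta_generic_to_shared(buf[0]),
        (unsigned int) __cvta_generic_to_shared(buf[1])};

    cp_async_cg_16<0>(buf32[0] + 16*threadIdx.x, src + 4*threadIdx.x); // preload chunk 0

    float sum = 0.0f;
    for (int c = 0; c < n_chunks; ++c) {
        cp_async_wait_all(); // chunk c has arrived in buf[c % 2]
        __syncwarp();

        if (c + 1 < n_chunks) { // start fetching chunk c+1 into the other buffer ...
            cp_async_cg_16<0>(buf32[(c + 1) % 2] + 16*threadIdx.x, src + (c + 1)*chunk + 4*threadIdx.x);
        }

#pragma unroll
        for (int i = 0; i < 4; ++i) { // ... while chunk c is consumed from shared memory
            sum += buf[c % 2][4*threadIdx.x + i];
        }
    }
    return sum;
}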
ggml/src/ggml-cuda/mma.cuh
CHANGED
@@ -4,11 +4,12 @@

The comment lines describing A as a row-major matrix and B and C as column-major matrices (lines 7-11) are rewritten for the new generic tile type; the surrounding header comment is unchanged:

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
//
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.

@@ -23,7 +24,7 @@

The output operand constraint of the movmatrix.sync.aligned.m8n8.trans.b16 inline assembly in ggml_cuda_movmatrix is also adjusted.

@@ -52,407 +53,267 @@

The bulk of the diff (+172 -311 overall) removes the old fixed-shape tile structs -- mma_A_I16K8, mma_B_J8K4, mma_B_J8K8, mma_C_I16J8 and their siblings, each carrying its own I/J/K/ne constants, get_i/get_j/get_k index helpers, load_generic/load_ldmatrix members and ldmatrix inline assembly -- in favor of the API already used by fattn-mma-f16.cuh above: a ggml_cuda_mma namespace providing a generic tile<I, J, T> template (with x[ne], get_i/get_j and the I/J/ne constants) together with free functions such as load_ldmatrix, load_ldmatrix_trans, get_transposed, get_half2 and mma.
|
| 270 |
-
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
|
| 271 |
-
GGML_CUDA_ASSUME(ret >= 0);
|
| 272 |
-
GGML_CUDA_ASSUME(ret < I);
|
| 273 |
-
return ret;
|
| 274 |
-
}
|
| 275 |
-
|
| 276 |
-
static __device__ __forceinline__ int get_j(const int l) {
|
| 277 |
-
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
|
| 278 |
-
GGML_CUDA_ASSUME(ret >= 0);
|
| 279 |
-
GGML_CUDA_ASSUME(ret < J);
|
| 280 |
-
return ret;
|
| 281 |
-
}
|
| 282 |
-
|
| 283 |
-
__device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
|
| 284 |
#ifdef NEW_MMA_AVAILABLE
|
| 285 |
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 286 |
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
| 287 |
-
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
| 288 |
-
: "r"(
|
| 289 |
#else
|
| 290 |
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
|
| 291 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 292 |
-
: "+r"(x[0]), "+r"(x[1])
|
| 293 |
-
: "r"(
|
| 294 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 295 |
-
: "+r"(x[2]), "+r"(x[3])
|
| 296 |
-
: "r"(
|
| 297 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 298 |
#else
|
| 299 |
-
GGML_UNUSED(
|
| 300 |
-
GGML_UNUSED(
|
|
|
|
| 301 |
NO_DEVICE_CODE;
|
| 302 |
#endif // NEW_MMA_AVAILABLE
|
| 303 |
}
|
| 304 |
|
| 305 |
-
__device__ __forceinline__ void mma(
|
|
|
|
| 306 |
#ifdef NEW_MMA_AVAILABLE
|
| 307 |
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 308 |
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
| 309 |
-
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
| 310 |
-
: "r"(
|
| 311 |
#else
|
| 312 |
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
|
| 313 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 314 |
-
: "+r"(x[0]), "+r"(x[1])
|
| 315 |
-
: "r"(
|
| 316 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 317 |
-
: "+r"(x[2]), "+r"(x[3])
|
| 318 |
-
: "r"(
|
| 319 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 320 |
-
: "+r"(x[0]), "+r"(x[1])
|
| 321 |
-
: "r"(
|
| 322 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 323 |
-
: "+r"(x[2]), "+r"(x[3])
|
| 324 |
-
: "r"(
|
| 325 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 326 |
#else
|
| 327 |
-
GGML_UNUSED(
|
| 328 |
-
GGML_UNUSED(
|
|
|
|
| 329 |
NO_DEVICE_CODE;
|
| 330 |
#endif // NEW_MMA_AVAILABLE
|
| 331 |
}
|
| 332 |
-
};
|
| 333 |
-
|
| 334 |
-
template <>
|
| 335 |
-
struct mma_C_I16J8<half2> {
|
| 336 |
-
static constexpr int I = 16;
|
| 337 |
-
static constexpr int J = 4;
|
| 338 |
-
static constexpr int ne = 2;
|
| 339 |
-
|
| 340 |
-
half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
|
| 341 |
-
|
| 342 |
-
static __device__ __forceinline__ int get_i(const int l) {
|
| 343 |
-
const int ret = l * (I/2) + threadIdx.x / J;
|
| 344 |
-
GGML_CUDA_ASSUME(ret >= 0);
|
| 345 |
-
GGML_CUDA_ASSUME(ret < I);
|
| 346 |
-
return ret;
|
| 347 |
-
}
|
| 348 |
|
| 349 |
-
static __device__ __forceinline__
|
| 350 |
-
|
| 351 |
-
GGML_CUDA_ASSUME(ret >= 0);
|
| 352 |
-
GGML_CUDA_ASSUME(ret < J);
|
| 353 |
-
return ret;
|
| 354 |
-
}
|
| 355 |
-
|
| 356 |
-
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
|
| 357 |
#ifdef NEW_MMA_AVAILABLE
|
| 358 |
-
int * Axi = (int *)
|
| 359 |
-
int * Bxi = (int *)
|
| 360 |
-
int
|
| 361 |
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 362 |
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
| 363 |
-
: "+r"(
|
| 364 |
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
|
| 365 |
#else
|
| 366 |
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
|
| 367 |
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
| 368 |
-
: "+r"(
|
| 369 |
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
| 370 |
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
| 371 |
-
: "+r"(
|
| 372 |
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
|
| 373 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 374 |
#else
|
| 375 |
-
GGML_UNUSED(
|
| 376 |
-
GGML_UNUSED(
|
|
|
|
| 377 |
NO_DEVICE_CODE;
|
| 378 |
#endif // NEW_MMA_AVAILABLE
|
| 379 |
}
|
| 380 |
|
| 381 |
-
__device__ __forceinline__
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
int * xi = (int *) x;
|
| 385 |
-
int * Bxi = (int *) mma_B.x;
|
| 386 |
-
Bxi[0] = ggml_cuda_movmatrix(xi[0]);
|
| 387 |
-
Bxi[1] = ggml_cuda_movmatrix(xi[1]);
|
| 388 |
-
|
| 389 |
-
return mma_B;
|
| 390 |
-
}
|
| 391 |
-
};
|
| 392 |
-
|
| 393 |
-
template <>
|
| 394 |
-
struct mma_C_I16J8<float> {
|
| 395 |
-
static constexpr int I = 16;
|
| 396 |
-
static constexpr int J = 8;
|
| 397 |
-
static constexpr int ne = 4;
|
| 398 |
-
|
| 399 |
-
float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
|
| 400 |
-
|
| 401 |
-
static __device__ __forceinline__ int get_i(const int l) {
|
| 402 |
-
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
|
| 403 |
-
GGML_CUDA_ASSUME(ret >= 0);
|
| 404 |
-
GGML_CUDA_ASSUME(ret < I);
|
| 405 |
-
return ret;
|
| 406 |
-
}
|
| 407 |
-
|
| 408 |
-
static __device__ __forceinline__ int get_j(const int l) {
|
| 409 |
-
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
|
| 410 |
-
GGML_CUDA_ASSUME(ret >= 0);
|
| 411 |
-
GGML_CUDA_ASSUME(ret < J);
|
| 412 |
-
return ret;
|
| 413 |
-
}
|
| 414 |
-
|
| 415 |
-
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
|
| 416 |
#ifdef NEW_MMA_AVAILABLE
|
| 417 |
-
int * Axi = (int *)
|
| 418 |
-
int * Bxi = (int *)
|
| 419 |
-
int
|
| 420 |
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 421 |
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
| 422 |
-
: "+r"(
|
| 423 |
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
|
| 424 |
#else
|
| 425 |
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
|
| 426 |
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
| 427 |
-
: "+r"(
|
| 428 |
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
| 429 |
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
| 430 |
-
: "+r"(
|
| 431 |
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
|
| 432 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 433 |
#else
|
| 434 |
-
GGML_UNUSED(
|
| 435 |
-
GGML_UNUSED(
|
|
|
|
| 436 |
NO_DEVICE_CODE;
|
| 437 |
#endif // NEW_MMA_AVAILABLE
|
| 438 |
}
|
| 439 |
|
| 440 |
-
|
| 441 |
-
mma_B_J8K8<half2> mma_B;
|
| 442 |
-
mma_B.x[0] = make_half2(x[0], x[1]);
|
| 443 |
-
mma_B.x[1] = make_half2(x[2], x[3]);
|
| 444 |
-
|
| 445 |
-
int * Bxi = (int *) mma_B.x;
|
| 446 |
-
Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
|
| 447 |
-
Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
|
| 448 |
-
|
| 449 |
-
return mma_B;
|
| 450 |
-
}
|
| 451 |
-
|
| 452 |
-
__device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
|
| 453 |
-
#pragma unroll
|
| 454 |
-
for (int l = 0; l < ne; ++l) {
|
| 455 |
-
x[l] = xs0[get_j(l)*stride + get_i(l)];
|
| 456 |
-
}
|
| 457 |
-
}
|
| 458 |
-
};
|
|
|
|
| 4 |
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
|
| 5 |
//
|
| 6 |
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
|
| 7 |
+
// A is a row-major matrix with shape M x K.
|
| 8 |
+
// B is a column-major matrix with shape K x N.
|
| 9 |
+
// C is a column-major matrix with shape M x N.
|
| 10 |
+
// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
|
| 11 |
+
// Note that J is measured in physical 32 bit elements instead of logical elements.
|
| 12 |
+
// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
|
| 13 |
// All matrix tiles have ne physical 32 bit elements per warp.
|
| 14 |
//
|
| 15 |
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
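The fragment-to-thread mapping that get_i/get_j encode can be checked on the host. The snippet below is an illustrative standalone program (not part of the commit) that reproduces the tile<16, 8, T> formulas for 32-bit T from this header and prints which (i, j) element each of the 32 lanes owns for l = 0..3.

#include <cstdio>

// Reproduces tile<16, 8, T>::get_i/get_j for 32-bit T (see the template below).
static int get_i_16x8(int lane, int l) { return (l / 2) * 8 + lane / 4; }
static int get_j_16x8(int lane, int l) { return 2 * (lane % 4) + l % 2; }

int main() {
    // ne = 16*8/32 = 4 physical 32-bit elements per lane.
    for (int lane = 0; lane < 32; ++lane) {
        printf("lane %2d:", lane);
        for (int l = 0; l < 4; ++l) {
            printf("  (i=%2d, j=%d)", get_i_16x8(lane, l), get_j_16x8(lane, l));
        }
        printf("\n");
    }
    return 0;
}

Lanes 0-3 end up owning rows 0 and 8 with column pairs interleaved, which matches the m16n8 accumulator layout described in the PTX documentation linked above.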
|
|
|
|
| 24 |
|
| 25 |
#ifdef NEW_MMA_AVAILABLE
|
| 26 |
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
| 27 |
+
: "=r"(ret) : "r"(x));
|
| 28 |
#else
|
| 29 |
NO_DEVICE_CODE;
|
| 30 |
#endif // defined(NEW_MMA_AVAILABLE)
|
|
|
|
| 53 |
|
| 54 |
#endif // CUDART_VERSION >= 11080
|
| 55 |
|
| 56 |
+
static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
|
| 57 |
+
half2 ret;
|
| 58 |
+
*((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
|
| 59 |
+
return ret;
|
| 60 |
+
}
|
| 61 |
|
| 62 |
+
namespace ggml_cuda_mma {
|
| 63 |
+
|
| 64 |
+
template <int I_, int J_, typename T>
|
| 65 |
+
struct tile {
|
| 66 |
+
static constexpr int I = I_;
|
| 67 |
+
static constexpr int J = J_;
|
| 68 |
+
static constexpr int ne = I * J / WARP_SIZE;
|
| 69 |
+
T x[ne] = {0};
|
| 70 |
+
|
| 71 |
+
static __device__ __forceinline__ int get_i(const int l) {
|
| 72 |
+
if constexpr (I == 8 && (J == 4 || J == 8)) {
|
| 73 |
+
return threadIdx.x / 4;
|
| 74 |
+
} else if constexpr (I == 16 && J == 8) {
|
| 75 |
+
return (l / 2) * 8 + threadIdx.x / 4;
|
| 76 |
+
} else {
|
| 77 |
+
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
|
| 81 |
+
static __device__ __forceinline__ int get_j(const int l) {
|
| 82 |
+
if constexpr (I == 8 && J == 4) {
|
| 83 |
+
return threadIdx.x % 4;
|
| 84 |
+
} else if constexpr (I == 8 && J == 8) {
|
| 85 |
+
return 4 * l + threadIdx.x % 4;
|
| 86 |
+
} else if constexpr (I == 16 && J == 8) {
|
| 87 |
+
return 2 * (threadIdx.x % 4) + l % 2;
|
| 88 |
+
} else {
|
| 89 |
+
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
};
|
| 93 |
+
|
| 94 |
+
template <int I_, int J_>
|
| 95 |
+
struct tile<I_, J_, half2> {
|
| 96 |
+
static constexpr int I = I_;
|
| 97 |
+
static constexpr int J = J_;
|
| 98 |
+
static constexpr int ne = I * J / WARP_SIZE;
|
| 99 |
+
half2 x[ne] = {{0.0f, 0.0f}};
|
| 100 |
+
|
| 101 |
+
static __device__ __forceinline__ int get_i(const int l) {
|
| 102 |
+
if constexpr (I == 8 && J == 8) {
|
| 103 |
+
return threadIdx.x / 4;
|
| 104 |
+
} else if constexpr (I == 16 && J == 4) {
|
| 105 |
+
return l * 8 + threadIdx.x / 4;
|
| 106 |
+
} else if constexpr (I == 16 && J == 8) {
|
| 107 |
+
return (l % 2) * 8 + threadIdx.x / 4;
|
| 108 |
+
} else {
|
| 109 |
+
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
|
| 113 |
+
static __device__ __forceinline__ int get_j(const int l) {
|
| 114 |
+
if constexpr (I == 8 && J == 8) {
|
| 115 |
+
return l * 4 + threadIdx.x % 4;
|
| 116 |
+
} else if constexpr (I == 16 && J == 4) {
|
| 117 |
+
return threadIdx.x % 4;
|
| 118 |
+
} else if constexpr (I == 16 && J == 8) {
|
| 119 |
+
return (l / 2) * 4 + threadIdx.x % 4;
|
| 120 |
+
} else {
|
| 121 |
+
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
};
|
| 125 |
|
| 126 |
+
template <int I, int J>
|
| 127 |
+
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
| 128 |
+
tile<I, J/2, half2> ret;
|
| 129 |
#pragma unroll
|
| 130 |
+
for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
|
| 131 |
+
ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
|
| 132 |
}
|
| 133 |
return ret;
|
| 134 |
}
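For reference, a minimal device-side sketch (not from this commit; the stock CUDA intrinsic __floats2half2_rn stands in for the make_half2 call above) of the same pairing rule that get_half2 applies: float elements l and l+1 of a per-thread fragment become half2 element l/2.

#include <cuda_fp16.h>
#include <cstdio>

__global__ void pack_half2_demo() {
    float acc[4] = {0.0f, 1.0f, 2.0f, 3.0f}; // stand-in for tile<16, 8, float>::x
    half2 packed[2];
#pragma unroll
    for (int l0 = 0; l0 < 4; l0 += 2) {
        packed[l0 / 2] = __floats2half2_rn(acc[l0 + 0], acc[l0 + 1]);
    }
    if (threadIdx.x == 0) {
        printf("packed[0] = (%f, %f)\n", __low2float(packed[0]), __high2float(packed[0]));
    }
}

int main() {
    pack_half2_demo<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}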
|
| 135 |
|
| 136 |
+
static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
|
| 137 |
+
tile<8, 8, half2> ret;
|
| 138 |
+
ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
|
| 139 |
+
ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
|
| 140 |
+
|
| 141 |
return ret;
|
| 142 |
}
|
| 143 |
|
| 144 |
+
template <int I, int J, typename T>
|
| 145 |
+
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
|
| 146 |
#pragma unroll
|
| 147 |
+
for (int l = 0; l < t.ne; ++l) {
|
| 148 |
+
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
| 149 |
}
|
| 150 |
}
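The generic path can be exercised without tensor cores. Below is a hypothetical standalone kernel (names and sizes are illustrative, not from the commit) that fills a 16x8 float tile in shared memory and gathers it per lane with the same get_i/get_j indexing and row stride that load_generic uses.

#include <cstdio>

__global__ void load_generic_demo() {
    constexpr int I = 16, J = 8, ne = I * J / 32;
    __shared__ float smem[I * J];

    // Fill the tile so that each element encodes its own flat index.
    for (int idx = threadIdx.x; idx < I * J; idx += blockDim.x) {
        smem[idx] = (float) idx;
    }
    __syncthreads();

    const float * xs0 = smem; // stride = J floats per row
    float x[ne];
#pragma unroll
    for (int l = 0; l < ne; ++l) {
        const int i = (l / 2) * 8 + threadIdx.x / 4; // tile<16, 8, float>::get_i(l)
        const int j = 2 * (threadIdx.x % 4) + l % 2; // tile<16, 8, float>::get_j(l)
        x[l] = xs0[i * J + j];
    }
    if (threadIdx.x == 0) {
        printf("lane 0 owns %g %g %g %g\n", x[0], x[1], x[2], x[3]);
    }
}

int main() {
    load_generic_demo<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}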
|
| 151 |
|
| 152 |
+
template <typename T>
|
| 153 |
+
static __device__ __forceinline__ void load_ldmatrix(
|
| 154 |
+
tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
| 155 |
#ifdef NEW_MMA_AVAILABLE
|
| 156 |
+
int * xi = (int *) t.x;
|
| 157 |
+
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
|
| 158 |
+
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
| 159 |
+
: "=r"(xi[0]), "=r"(xi[1])
|
| 160 |
: "l"(xs));
|
| 161 |
#else
|
| 162 |
+
load_generic(t, xs0, stride);
|
|
|
|
|
|
|
| 163 |
#endif // NEW_MMA_AVAILABLE
|
| 164 |
}
|
| 165 |
|
| 166 |
+
template <typename T>
|
| 167 |
+
static __device__ __forceinline__ void load_ldmatrix(
|
| 168 |
+
tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
|
| 169 |
#ifdef NEW_MMA_AVAILABLE
|
| 170 |
+
int * xi = (int *) t.x;
|
| 171 |
+
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
|
| 172 |
+
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
| 173 |
+
: "=r"(xi[0]), "=r"(xi[1])
|
| 174 |
: "l"(xs));
|
| 175 |
#else
|
| 176 |
+
load_generic(t, xs0, stride);
|
|
|
|
|
|
|
| 177 |
#endif // NEW_MMA_AVAILABLE
|
| 178 |
}
|
| 179 |
|
| 180 |
+
template <typename T>
|
| 181 |
+
static __device__ __forceinline__ void load_ldmatrix(
|
| 182 |
+
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
| 183 |
#ifdef NEW_MMA_AVAILABLE
|
| 184 |
+
int * xi = (int * ) t.x;
|
| 185 |
+
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
| 186 |
+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
| 187 |
+
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
|
| 188 |
+
: "l"(xs));
|
| 189 |
#else
|
| 190 |
+
load_generic(t, xs0, stride);
|
| 191 |
#endif // NEW_MMA_AVAILABLE
|
| 192 |
}
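The address arithmetic above decides which shared-memory row each lane hands to ldmatrix. A small host-side check (illustrative only, assuming a contiguous 16x8 buffer of 32-bit elements) makes the split visible: lanes 0-15 provide the rows of the two left 8x8 b16 sub-matrices, lanes 16-31 the rows of the two right ones.

#include <cstdio>

int main() {
    const int I = 16, J = 8;  // tile shape in 32-bit elements
    const int stride = J;     // assume a contiguous 16x8 buffer
    for (int lane = 0; lane < 32; ++lane) {
        const int offset = (lane % I) * stride + (lane / I) * (J / 2);
        printf("lane %2d -> row %2d, 32-bit column %d (offset %3d)\n",
               lane, offset / stride, offset % stride, offset);
    }
    return 0;
}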
|
| 193 |
|
| 194 |
+
template <typename T>
|
| 195 |
+
static __device__ __forceinline__ void load_ldmatrix_trans(
|
| 196 |
+
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
| 197 |
#ifdef NEW_MMA_AVAILABLE
|
| 198 |
+
int * xi = (int * ) t.x;
|
| 199 |
+
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
| 200 |
+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
|
| 201 |
+
: "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
|
| 202 |
: "l"(xs));
|
| 203 |
#else
|
| 204 |
+
GGML_UNUSED(t);
|
| 205 |
+
GGML_UNUSED(xs0);
|
| 206 |
+
GGML_UNUSED(stride);
|
| 207 |
+
NO_DEVICE_CODE;
|
| 208 |
#endif // NEW_MMA_AVAILABLE
|
| 209 |
}
|
| 210 |
|
| 211 |
+
static __device__ __forceinline__ void mma(
|
| 212 |
+
tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
|
| 213 |
#ifdef NEW_MMA_AVAILABLE
|
| 214 |
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 215 |
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
| 216 |
+
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
|
| 217 |
+
: "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
|
| 218 |
#else
|
| 219 |
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
|
| 220 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 221 |
+
: "+r"(D.x[0]), "+r"(D.x[1])
|
| 222 |
+
: "r"(A.x[0]), "r"(B.x[0]));
|
| 223 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 224 |
+
: "+r"(D.x[2]), "+r"(D.x[3])
|
| 225 |
+
: "r"(A.x[1]), "r"(B.x[0]));
|
| 226 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 227 |
#else
|
| 228 |
+
GGML_UNUSED(D);
|
| 229 |
+
GGML_UNUSED(A);
|
| 230 |
+
GGML_UNUSED(B);
|
| 231 |
NO_DEVICE_CODE;
|
| 232 |
#endif // NEW_MMA_AVAILABLE
|
| 233 |
}
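What the instruction accumulates can be written out as a plain reference. The snippet below is a host-side sketch (not ggml code) of the m16n8k16 semantics used above: D (16x8, s32) += A (16x16, s8, row-major) x B (16x8, s8, column-major), with every product widened to 32 bits.

#include <cstdint>
#include <cstdio>

int main() {
    const int M = 16, N = 8, K = 16;
    int8_t  A[M][K];      // row-major: A[m][k]
    int8_t  B[N][K];      // column-major: B[n][k] is logical element (k, n)
    int32_t D[M][N] = {}; // s32 accumulator, starts at zero

    for (int m = 0; m < M; ++m) for (int k = 0; k < K; ++k) A[m][k] = (int8_t) (m - k);
    for (int n = 0; n < N; ++n) for (int k = 0; k < K; ++k) B[n][k] = (int8_t) (n + k);

    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            for (int k = 0; k < K; ++k) {
                D[m][n] += (int32_t) A[m][k] * (int32_t) B[n][k];
            }
        }
    }
    printf("D[0][0] = %d, D[15][7] = %d\n", D[0][0], D[15][7]);
    return 0;
}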
|
| 234 |
|
| 235 |
+
static __device__ __forceinline__ void mma(
|
| 236 |
+
tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
|
| 237 |
#ifdef NEW_MMA_AVAILABLE
|
| 238 |
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 239 |
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
| 240 |
+
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
|
| 241 |
+
: "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
|
| 242 |
#else
|
| 243 |
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
|
| 244 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 245 |
+
: "+r"(D.x[0]), "+r"(D.x[1])
|
| 246 |
+
: "r"(A.x[0]), "r"(B.x[0]));
|
| 247 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 248 |
+
: "+r"(D.x[2]), "+r"(D.x[3])
|
| 249 |
+
: "r"(A.x[1]), "r"(B.x[0]));
|
| 250 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 251 |
+
: "+r"(D.x[0]), "+r"(D.x[1])
|
| 252 |
+
: "r"(A.x[2]), "r"(B.x[1]));
|
| 253 |
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
| 254 |
+
: "+r"(D.x[2]), "+r"(D.x[3])
|
| 255 |
+
: "r"(A.x[3]), "r"(B.x[1]));
|
| 256 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 257 |
#else
|
| 258 |
+
GGML_UNUSED(D);
|
| 259 |
+
GGML_UNUSED(A);
|
| 260 |
+
GGML_UNUSED(B);
|
| 261 |
NO_DEVICE_CODE;
|
| 262 |
#endif // NEW_MMA_AVAILABLE
|
| 263 |
}
|
| 264 |
|
| 265 |
+
static __device__ __forceinline__ void mma(
|
| 266 |
+
tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
| 267 |
#ifdef NEW_MMA_AVAILABLE
|
| 268 |
+
const int * Axi = (const int *) A.x;
|
| 269 |
+
const int * Bxi = (const int *) B.x;
|
| 270 |
+
int * Dxi = (int *) D.x;
|
| 271 |
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 272 |
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
| 273 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
| 274 |
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
|
| 275 |
#else
|
| 276 |
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
|
| 277 |
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
| 278 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
| 279 |
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
| 280 |
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
| 281 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
| 282 |
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
|
| 283 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 284 |
#else
|
| 285 |
+
GGML_UNUSED(D);
|
| 286 |
+
GGML_UNUSED(A);
|
| 287 |
+
GGML_UNUSED(B);
|
| 288 |
NO_DEVICE_CODE;
|
| 289 |
#endif // NEW_MMA_AVAILABLE
|
| 290 |
}
|
| 291 |
|
| 292 |
+
static __device__ __forceinline__ void mma(
|
| 293 |
+
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
| 294 |
#ifdef NEW_MMA_AVAILABLE
|
| 295 |
+
const int * Axi = (const int *) A.x;
|
| 296 |
+
const int * Bxi = (const int *) B.x;
|
| 297 |
+
int * Dxi = (int *) D.x;
|
| 298 |
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 299 |
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
| 300 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
| 301 |
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
|
| 302 |
#else
|
| 303 |
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
|
| 304 |
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
| 305 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
| 306 |
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
| 307 |
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
| 308 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
| 309 |
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
|
| 310 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 311 |
#else
|
| 312 |
+
GGML_UNUSED(D);
|
| 313 |
+
GGML_UNUSED(A);
|
| 314 |
+
GGML_UNUSED(B);
|
| 315 |
NO_DEVICE_CODE;
|
| 316 |
#endif // NEW_MMA_AVAILABLE
|
| 317 |
}
|
| 318 |
|
| 319 |
+
}
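Taken together, the new interface is used roughly as follows. This is a sketch, not code from the commit: the include path, buffer sizes and launch shape are assumptions, it requires a GPU where NEW_MMA_AVAILABLE is defined (Turing or newer), and it only shows the call pattern: load two operand tiles from 16-byte-aligned shared memory, issue one mma, then scatter the accumulator with get_i/get_j.

#include <cuda_fp16.h>
#include "ggml-cuda/mma.cuh" // hypothetical include path for this header

using namespace ggml_cuda_mma;

// One warp computes a 16x8 float accumulator from half2 operand tiles.
__global__ void tile_mma_demo(const half2 * A_in, const half2 * B_in, float * C_out) {
    __shared__ __align__(16) half2 A_sh[16 * 8]; // 16 rows, stride 8 half2 (16x16 half)
    __shared__ __align__(16) half2 B_sh[ 8 * 8]; //  8 rows, stride 8 half2 ( 8x16 half)

    for (int i = threadIdx.x; i < 16 * 8; i += 32) A_sh[i] = A_in[i];
    for (int i = threadIdx.x; i <  8 * 8; i += 32) B_sh[i] = B_in[i];
    __syncthreads();

    tile<16, 8, half2> A;
    tile< 8, 8, half2> B;
    tile<16, 8, float> C; // zero-initialized accumulator

    load_ldmatrix(A, A_sh, 8); // stride in half2 elements per row
    load_ldmatrix(B, B_sh, 8);
    mma(C, A, B);              // tensor-core path on Turing+, NO_DEVICE_CODE otherwise

#pragma unroll
    for (int l = 0; l < C.ne; ++l) {
        C_out[C.get_i(l) * 8 + C.get_j(l)] = C.x[l];
    }
}

Launched with a single warp, e.g. tile_mma_demo<<<1, 32>>>(dA, dB, dC), after the operands have been copied to the device.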
|
|
ggml/src/ggml-cuda/mmq.cuh
CHANGED
|
@@ -7,6 +7,8 @@
@@ -647,15 +649,15 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
@@ -663,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
@@ -674,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
@@ -691,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
@@ -712,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
@@ -758,23 +760,23 @@ template <int mmq_x, int mmq_y, int nwarps>
@@ -784,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
@@ -801,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
@@ -868,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
@@ -895,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
@@ -912,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
@@ -1056,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
@@ -1084,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
@@ -1107,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
@@ -1166,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
@@ -1708,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
@@ -1724,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
@@ -1736,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
@@ -1745,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
@@ -1759,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
@@ -1802,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
@@ -2312,36 +2314,36 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
[old-side column clipped by the diff viewer: the changed lines replace the old mma_A_I16K*/mma_B_J8K*/mma_C_I16J8 typedefs and member calls with the ggml_cuda_mma tile typedefs and free-function calls. The same hunks are repeated with new line numbers and full added lines below.]
|
|
|
| 7 |
#include <climits>
|
| 8 |
#include <cstdint>
|
| 9 |
|
| 10 |
+
using namespace ggml_cuda_mma;
|
| 11 |
+
|
| 12 |
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
|
| 13 |
#define MMQ_ITER_K 256
|
| 14 |
#define MMQ_NWARPS 8
|
|
|
|
| 649 |
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
| 650 |
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
| 651 |
|
| 652 |
+
typedef tile<16, 8, int> tile_A;
|
| 653 |
+
typedef tile< 8, 8, int> tile_B;
|
| 654 |
+
typedef tile<16, 8, int> tile_C;
|
| 655 |
|
| 656 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 657 |
constexpr int rows_per_warp = 2 * granularity;
|
| 658 |
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
| 659 |
|
| 660 |
+
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
|
| 661 |
|
| 662 |
const int * x_qs = (const int *) x;
|
| 663 |
const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
|
|
|
|
| 665 |
const float * y_df = (const float *) y;
|
| 666 |
const half2 * y_ds = (const half2 *) y;
|
| 667 |
|
| 668 |
+
tile_A A[ntx][WARP_SIZE/QI8_0];
|
| 669 |
+
float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
|
| 670 |
|
| 671 |
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
|
| 672 |
|
|
|
|
| 676 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
|
| 677 |
const int k0 = k00 + k01;
|
| 678 |
|
| 679 |
+
load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
| 680 |
}
|
| 681 |
|
| 682 |
#pragma unroll
|
| 683 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 684 |
+
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
|
| 685 |
|
| 686 |
#pragma unroll
|
| 687 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
|
|
|
|
| 693 |
}
|
| 694 |
|
| 695 |
#pragma unroll
|
| 696 |
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
| 697 |
#pragma unroll
|
| 698 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
|
| 699 |
+
tile_B B;
|
| 700 |
+
float dB[tile_C::ne/2];
|
| 701 |
|
| 702 |
+
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
|
| 703 |
|
| 704 |
#pragma unroll
|
| 705 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 706 |
+
const int j = j0 + tile_C::get_j(l);
|
| 707 |
|
| 708 |
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
|
| 709 |
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
|
|
|
| 714 |
|
| 715 |
#pragma unroll
|
| 716 |
for (int n = 0; n < ntx; ++n) {
|
| 717 |
+
tile_C C;
|
| 718 |
+
mma(C, A[n][k01/QI8_0], B);
|
| 719 |
|
| 720 |
#pragma unroll
|
| 721 |
+
for (int l = 0; l < tile_C::ne; ++l) {
|
| 722 |
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
|
| 723 |
}
|
| 724 |
}
|
| 725 |
}
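The loop bounds above follow from a small amount of arithmetic. A worked host-side example (the granularity and mmq_x values are assumed for illustration only; mmq_get_granularity_device decides the real value) shows how ntx, the j0 step and the flat sum[] offsets relate for tile_C = tile<16, 8, int>.

#include <cstdio>

int main() {
    const int granularity   = 16;              // assumed example value
    const int rows_per_warp = 2 * granularity; // 32
    const int tile_C_I = 16, tile_C_J = 8, tile_C_ne = 4;
    const int ntx = rows_per_warp / tile_C_I;  // 2 x minitiles per warp

    const int mmq_x = 64;                      // assumed example tile width
    for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C_J) {
        for (int n = 0; n < ntx; ++n) {
            // Flat offset of this minitile's 4 accumulators in sum[]:
            printf("j0 = %2d, n = %d -> sum[%2d ... %2d]\n", j0, n,
                   (j0 / tile_C_J + n) * tile_C_ne,
                   (j0 / tile_C_J + n) * tile_C_ne + tile_C_ne - 1);
        }
    }
    return 0;
}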
|
|
|
|
| 760 |
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
| 761 |
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
| 762 |
|
| 763 |
+
typedef tile<16, 8, int> tile_A;
|
| 764 |
+
typedef tile< 8, 8, int> tile_B;
|
| 765 |
+
typedef tile<16, 8, int> tile_C;
|
| 766 |
|
| 767 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 768 |
constexpr int rows_per_warp = 2 * granularity;
|
| 769 |
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
| 770 |
|
| 771 |
+
y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
|
| 772 |
|
| 773 |
const int * x_qs = (const int *) x;
|
| 774 |
const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
|
| 775 |
const int * y_qs = (const int *) y + 4;
|
| 776 |
const half2 * y_dm = (const half2 *) y;
|
| 777 |
|
| 778 |
+
tile_A A[ntx][WARP_SIZE/QI8_1];
|
| 779 |
+
float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
|
| 780 |
|
| 781 |
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
|
| 782 |
|
|
|
|
| 786 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
| 787 |
const int k0 = k00 + k01;
|
| 788 |
|
| 789 |
+
load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
| 790 |
}
|
| 791 |
|
| 792 |
#pragma unroll
|
| 793 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 794 |
+
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
|
| 795 |
|
| 796 |
#pragma unroll
|
| 797 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
|
|
|
| 803 |
}
|
| 804 |
|
| 805 |
#pragma unroll
|
| 806 |
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
| 807 |
#pragma unroll
|
| 808 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
| 809 |
+
tile_B B;
|
| 810 |
+
float2 dsB[tile_C::ne/2];
|
| 811 |
|
| 812 |
+
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
|
| 813 |
|
| 814 |
#pragma unroll
|
| 815 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 816 |
+
const int j = j0 + tile_C::get_j(l);
|
| 817 |
|
| 818 |
dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
| 819 |
}
|
| 820 |
|
| 821 |
#pragma unroll
|
| 822 |
for (int n = 0; n < ntx; ++n) {
|
| 823 |
+
tile_C C;
|
| 824 |
+
mma(C, A[n][k01/QI8_1], B);
|
| 825 |
|
| 826 |
#pragma unroll
|
| 827 |
+
for (int l = 0; l < tile_C::ne; ++l) {
|
| 828 |
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
|
| 829 |
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
|
| 830 |
}
|
| 831 |
}
|
| 832 |
}
|
|
|
|
| 870 |
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
| 871 |
#ifdef NEW_MMA_AVAILABLE
|
| 872 |
|
| 873 |
+
typedef tile<16, 4, int> tile_A;
|
| 874 |
+
typedef tile<16, 8, int> tile_A_8;
|
| 875 |
+
typedef tile< 8, 4, int> tile_B;
|
| 876 |
+
typedef tile<16, 8, int> tile_C;
|
| 877 |
|
| 878 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 879 |
constexpr int rows_per_warp = 2 * granularity;
|
| 880 |
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
| 881 |
|
| 882 |
+
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
|
| 883 |
|
| 884 |
const int * x_qs = (const int *) x;
|
| 885 |
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
|
| 886 |
const int * y_qs = (const int *) y + 4;
|
| 887 |
const float * y_df = (const float *) y;
|
| 888 |
|
| 889 |
+
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
|
| 890 |
|
| 891 |
+
tile_A A[ntx][8];
|
| 892 |
+
float dA[ntx][tile_C::ne/2][8];
|
| 893 |
|
| 894 |
#pragma unroll
|
| 895 |
for (int n = 0; n < ntx; ++n) {
|
|
|
|
| 897 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
|
| 898 |
const int k0 = k00 + k01;
|
| 899 |
|
| 900 |
+
load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
| 901 |
}
|
| 902 |
|
| 903 |
#pragma unroll
|
| 904 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 905 |
+
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
| 906 |
|
| 907 |
#pragma unroll
|
| 908 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
|
|
|
|
| 914 |
}
|
| 915 |
|
| 916 |
#pragma unroll
|
| 917 |
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
| 918 |
#pragma unroll
|
| 919 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
|
| 920 |
+
tile_B B[2];
|
| 921 |
+
float dB[tile_C::ne/2];
|
| 922 |
|
| 923 |
// Here load_generic is faster than load_ldmatrix.
|
| 924 |
+
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
|
| 925 |
+
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
|
| 926 |
|
| 927 |
#pragma unroll
|
| 928 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 929 |
+
const int j = j0 + tile_C::get_j(l);
|
| 930 |
|
| 931 |
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
| 932 |
}
|
| 933 |
|
| 934 |
#pragma unroll
|
| 935 |
for (int n = 0; n < ntx; ++n) {
|
| 936 |
+
tile_C C[2];
|
| 937 |
+
mma(C[0], A[n][k01/4 + 0], B[0]);
|
| 938 |
+
mma(C[1], A[n][k01/4 + 1], B[1]);
|
| 939 |
|
| 940 |
#pragma unroll
|
| 941 |
+
for (int l = 0; l < tile_C::ne; ++l) {
|
| 942 |
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
|
| 943 |
}
|
| 944 |
}
|
| 945 |
}
|
|
|
|
| 1058 |
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
| 1059 |
#ifdef NEW_MMA_AVAILABLE
|
| 1060 |
|
| 1061 |
+
typedef tile<16, 4, int> tile_A;
|
| 1062 |
+
typedef tile<16, 8, int> tile_A_8;
|
| 1063 |
+
typedef tile< 8, 4, int> tile_B;
|
| 1064 |
+
typedef tile<16, 8, int> tile_C;
|
| 1065 |
|
| 1066 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 1067 |
constexpr int rows_per_warp = 2 * granularity;
|
| 1068 |
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
| 1069 |
|
| 1070 |
+
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
|
| 1071 |
|
| 1072 |
const int * x_qs = (const int *) x;
|
| 1073 |
const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
|
| 1074 |
const int * y_qs = (const int *) y + 4;
|
| 1075 |
const half2 * y_ds = (const half2 *) y;
|
| 1076 |
|
| 1077 |
+
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
|
| 1078 |
|
| 1079 |
+
tile_A A[ntx][8];
|
| 1080 |
+
float dA[ntx][tile_C::ne/2][8];
|
| 1081 |
+
float mA[ntx][tile_C::ne/2][8];
|
| 1082 |
|
| 1083 |
#pragma unroll
|
| 1084 |
for (int n = 0; n < ntx; ++n) {
|
|
|
|
| 1086 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
| 1087 |
const int k0 = k00 + k01;
|
| 1088 |
|
| 1089 |
+
load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
| 1090 |
}
|
| 1091 |
}
|
| 1092 |
|
| 1093 |
#pragma unroll
|
| 1094 |
for (int n = 0; n < ntx; ++n) {
|
| 1095 |
#pragma unroll
|
| 1096 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 1097 |
+
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
| 1098 |
|
| 1099 |
#pragma unroll
|
| 1100 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
|
|
|
|
| 1109 |
}
|
| 1110 |
|
| 1111 |
#pragma unroll
|
| 1112 |
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
| 1113 |
+
float2 dB[tile_C::ne/2];
|
| 1114 |
|
| 1115 |
#pragma unroll
|
| 1116 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 1117 |
+
const int j = j0 + tile_C::get_j(l);
|
| 1118 |
|
| 1119 |
dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
|
| 1120 |
}
|
| 1121 |
|
| 1122 |
#pragma unroll
|
| 1123 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
| 1124 |
+
tile_B B[2];
|
| 1125 |
|
| 1126 |
// Here load_generic is faster than load_ldmatrix.
|
| 1127 |
+
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
|
| 1128 |
+
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
|
| 1129 |
|
| 1130 |
+
tile_C Cm[2];
|
| 1131 |
if (k01 >= WARP_SIZE * 3/4) {
|
| 1132 |
+
tile_A A1;
|
| 1133 |
A1.x[0] = 0x01010101;
|
| 1134 |
A1.x[1] = 0x01010101;
|
| 1135 |
+
mma(Cm[0], A1, B[0]);
|
| 1136 |
+
mma(Cm[1], A1, B[1]);
|
| 1137 |
}
|
| 1138 |
|
| 1139 |
#pragma unroll
|
| 1140 |
for (int n = 0; n < ntx; ++n) {
|
| 1141 |
+
tile_C Cd[2];
|
| 1142 |
|
| 1143 |
+
mma(Cd[0], A[n][k01/4 + 0], B[0]);
|
| 1144 |
+
mma(Cd[1], A[n][k01/4 + 1], B[1]);
|
| 1145 |
|
| 1146 |
#pragma unroll
|
| 1147 |
+
for (int l = 0; l < tile_C::ne; ++l) {
|
| 1148 |
float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
|
| 1149 |
if (k01 >= WARP_SIZE * 3/4) {
|
| 1150 |
tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
|
| 1151 |
}
|
| 1152 |
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
|
| 1153 |
}
|
| 1154 |
}
|
| 1155 |
}
|
| 1156 |
|
| 1157 |
#pragma unroll
|
| 1158 |
for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
|
| 1159 |
+
float2 sB[tile_C::ne/2];
|
| 1160 |
|
| 1161 |
#pragma unroll
|
| 1162 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 1163 |
+
const int j = j0 + tile_C::get_j(l);
|
| 1164 |
|
| 1165 |
sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
|
| 1166 |
}
|
|
|
|
| 1168 |
#pragma unroll
|
| 1169 |
for (int n = 0; n < ntx; ++n) {
|
| 1170 |
#pragma unroll
|
| 1171 |
+
for (int l = 0; l < tile_C::ne; ++l) {
|
| 1172 |
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
|
| 1173 |
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
|
| 1174 |
}
|
| 1175 |
}
|
| 1176 |
}
|
|
|
|
| 1710 |
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
| 1711 |
#ifdef NEW_MMA_AVAILABLE
|
| 1712 |
|
| 1713 |
+
typedef tile<16, 4, int> tile_A;
|
| 1714 |
+
typedef tile< 8, 4, int> tile_B;
|
| 1715 |
+
typedef tile<16, 8, int> tile_C;
|
| 1716 |
|
| 1717 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 1718 |
constexpr int rows_per_warp = 2 * granularity;
|
| 1719 |
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
| 1720 |
|
| 1721 |
+
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
|
| 1722 |
|
| 1723 |
const int * x_qs = (const int *) x;
|
| 1724 |
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
|
|
|
|
| 1726 |
const int * y_qs = (const int *) y + 4;
|
| 1727 |
const float * y_df = (const float *) y;
|
| 1728 |
|
| 1729 |
+
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
|
| 1730 |
|
| 1731 |
+
tile_A A[ntx][8];
|
| 1732 |
+
int scA[ntx][tile_C::ne/2][8];
|
| 1733 |
+
float dA[ntx][tile_C::ne/2];
|
| 1734 |
|
| 1735 |
#pragma unroll
|
| 1736 |
for (int n = 0; n < ntx; ++n) {
|
|
|
|
| 1738 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
|
| 1739 |
const int k0 = k00 + k01;
|
| 1740 |
|
| 1741 |
+
load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
|
| 1742 |
+
load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
|
| 1743 |
}
|
| 1744 |
|
| 1745 |
#pragma unroll
|
|
|
|
| 1747 |
const int k0 = k00 + k01;
|
| 1748 |
|
| 1749 |
#pragma unroll
|
| 1750 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 1751 |
+
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
| 1752 |
|
| 1753 |
const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
|
| 1754 |
const int8_t * sc = (const int8_t *) &sc_packed;
|
|
|
|
| 1761 |
}
|
| 1762 |
|
| 1763 |
#pragma unroll
|
| 1764 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 1765 |
+
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
| 1766 |
|
| 1767 |
dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
|
| 1768 |
}
|
| 1769 |
}
|
| 1770 |
|
| 1771 |
#pragma unroll
|
| 1772 |
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
| 1773 |
+
float tmp[ntx][tile_C::ne] = {{0.0f}};
|
| 1774 |
|
| 1775 |
#pragma unroll
|
| 1776 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
|
| 1777 |
+
tile_B B[2];
|
| 1778 |
+
float dB[tile_C::ne/2];
|
| 1779 |
|
| 1780 |
// Here load_generic is faster than load_ldmatrix.
|
| 1781 |
+
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
|
| 1782 |
+
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
|
| 1783 |
|
| 1784 |
#pragma unroll
|
| 1785 |
+
for (int l = 0; l < tile_C::ne/2; ++l) {
|
| 1786 |
+
const int j = j0 + tile_C::get_j(l);
|
| 1787 |
|
| 1788 |
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
| 1789 |
}
|
| 1790 |
|
| 1791 |
#pragma unroll
|
| 1792 |
for (int n = 0; n < ntx; ++n) {
|
| 1793 |
+
tile_C C[2];
|
| 1794 |
+
mma(C[0], A[n][k01/4 + 0], B[0]);
|
| 1795 |
+
mma(C[1], A[n][k01/4 + 1], B[1]);
|
| 1796 |
|
| 1797 |
#pragma unroll
|
| 1798 |
+
for (int l = 0; l < tile_C::ne; ++l) {
|
| 1799 |
tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
|
| 1800 |
}
|
| 1801 |
}
|
|
|
|
| 1804 |
#pragma unroll
|
| 1805 |
for (int n = 0; n < ntx; ++n) {
|
| 1806 |
#pragma unroll
|
| 1807 |
+
for (int l = 0; l < tile_C::ne; ++l) {
|
| 1808 |
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
|
| 1809 |
}
|
| 1810 |
}
|
| 1811 |
}
|
|
|
|
| 2314 |
static __device__ __forceinline__ void mmq_write_back_mma(
|
| 2315 |
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
|
| 2316 |
|
| 2317 |
+
typedef tile<16, 8, int> tile_C;
|
| 2318 |
|
| 2319 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 2320 |
constexpr int rows_per_warp = 2 * granularity;
|
| 2321 |
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
| 2322 |
|
| 2323 |
+
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
|
| 2324 |
#ifdef NEW_MMA_AVAILABLE
|
| 2325 |
+
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
|
| 2326 |
#endif // NEW_MMA_AVAILABLE
|
| 2327 |
|
| 2328 |
#pragma unroll
|
| 2329 |
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
| 2330 |
#pragma unroll
|
| 2331 |
for (int n = 0; n < ntx; ++n) {
|
| 2332 |
#pragma unroll
|
| 2333 |
+
for (int l = 0; l < tile_C::ne; ++l) {
|
| 2334 |
+
const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
|
| 2335 |
|
| 2336 |
if (j > j_max) {
|
| 2337 |
continue;
|
| 2338 |
}
|
| 2339 |
|
| 2340 |
+
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
| 2341 |
|
| 2342 |
if (need_check && i > i_max) {
|
| 2343 |
continue;
|
| 2344 |
}
|
| 2345 |
|
| 2346 |
+
dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
|
| 2347 |
}
|
| 2348 |
}
|
| 2349 |
}
|