JohannesGaessler and Diego Devesa committed on
Commit 5b9980d · 1 Parent(s): 3c2171d

CUDA: use async data loading for FlashAttention (llama/11894)


* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <[email protected]>

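The change overlaps global-to-shared copies of the next K/V tile with the tensor core math on the current one. For context, below is a minimal, hedged sketch of that double-buffering pattern built on the cp_async_cg_16/cp_async_wait_all helpers this commit adds in ggml/src/ggml-cuda/cp-async.cuh; stream_chunks, load_chunk_async, the chunk size and the summation step are made-up stand-ins, not the FlashAttention kernel itself (which streams K/V tiles into MMA instructions instead).

#include "cp-async.cuh" // added by this commit: cp_async_cg_16 / cp_async_wait_all

// Issue asynchronous 16-byte copies for one chunk of chunk_h2 half2 values.
// Assumes both pointers are aligned to 16 bytes and CP_ASYNC_AVAILABLE (Ampere or newer).
template <int chunk_h2>
static __device__ void load_chunk_async(const half2 * __restrict__ src, half2 * __restrict__ dst_smem, const int tid, const int nthreads) {
    constexpr int h2_per_cp = 16/sizeof(half2); // 4 half2 == one 16 byte cp.async transaction
    const unsigned int dst_32 = __cvta_generic_to_shared(dst_smem);
    for (int k = tid*h2_per_cp; k < chunk_h2; k += nthreads*h2_per_cp) {
        cp_async_cg_16<0>(dst_32 + k*sizeof(half2), src + k);
    }
}

// Hypothetical kernel: stream nchunks chunks through two shared buffers,
// prefetching chunk c+1 while chunk c is consumed (here simply summed up).
// Launch with blockDim = (WARP_SIZE, nwarps) and dst holding one float per thread.
template <int nwarps>
static __global__ void stream_chunks(const half2 * __restrict__ src, float * __restrict__ dst, const int nchunks) {
    constexpr int chunk_h2 = 256; // half2 elements per chunk (1 KiB)
    __shared__ __align__(16) half2 buf[2][chunk_h2];

    const int tid      = threadIdx.y*WARP_SIZE + threadIdx.x;
    const int nthreads = nwarps*WARP_SIZE;

    load_chunk_async<chunk_h2>(src, buf[0], tid, nthreads); // kick off the first copy

    float sum = 0.0f;
    for (int c = 0; c < nchunks; ++c) {
        cp_async_wait_all(); // this thread's copies of chunk c are done ...
        __syncthreads();     // ... and the copies issued by the other warps are visible

        if (c + 1 < nchunks) { // prefetch the next chunk into the other buffer
            load_chunk_async<chunk_h2>(src + (c + 1)*chunk_h2, buf[(c + 1) % 2], tid, nthreads);
        }

        // Consume the resident chunk while the next copy is still in flight:
        for (int k = tid; k < chunk_h2; k += nthreads) {
            sum += __low2float(buf[c % 2][k]) + __high2float(buf[c % 2][k]);
        }
        __syncthreads(); // buf[c % 2] must not be overwritten before all warps finished reading it
    }
    dst[tid] = sum;
}

The real kernel uses the same wait, then __syncthreads, then prefetch ordering, but with two distinct buffers tile_K and tile_V so that the V copy can run during the KQ matrix multiplication.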
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -41,12 +41,13 @@
41
  #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
42
  #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
43
 
44
- #define GGML_CUDA_CC_PASCAL 600
45
- #define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
46
- #define GGML_CUDA_CC_VOLTA 700
47
- #define GGML_CUDA_CC_TURING 750
48
- #define GGML_CUDA_CC_AMPERE 800
49
- #define GGML_CUDA_CC_OFFSET_AMD 0x1000000
 
50
 
51
  // GCN/CDNA, wave size is 64
52
  #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
@@ -199,6 +200,10 @@ typedef float2 dfloat2;
199
  #define NEW_MMA_AVAILABLE
200
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
201
 
 
 
 
 
202
  #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
203
  #define FLASH_ATTN_AVAILABLE
204
  #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {
231
  return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
232
  }
233
 
 
 
 
 
234
  static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
235
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
236
  return __AMDGCN_WAVEFRONT_SIZE;
 
41
  #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
42
  #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
43
 
44
+ #define GGML_CUDA_CC_PASCAL 600
45
+ #define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
46
+ #define GGML_CUDA_CC_VOLTA 700
47
+ #define GGML_CUDA_CC_TURING 750
48
+ #define GGML_CUDA_CC_AMPERE 800
49
+ #define GGML_CUDA_CC_ADA_LOVELACE 890
50
+ #define GGML_CUDA_CC_OFFSET_AMD 0x1000000
51
 
52
  // GCN/CDNA, wave size is 64
53
  #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
 
200
  #define NEW_MMA_AVAILABLE
201
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
202
 
203
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
204
+ #define CP_ASYNC_AVAILABLE
205
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
206
+
207
  #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
208
  #define FLASH_ATTN_AVAILABLE
209
  #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 
236
  return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
237
  }
238
 
239
+ static bool cp_async_available(const int cc) {
240
+ return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
241
+ }
242
+
243
  static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
244
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
245
  return __AMDGCN_WAVEFRONT_SIZE;
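A hedged illustration (placeholder names, not code from this commit) of how the two additions are meant to combine: CP_ASYNC_AVAILABLE gates device code at compile time per target architecture, while cp_async_available(cc) lets host code make the matching decision at runtime, for example when sizing dynamic shared memory as fattn-mma-f16.cuh does further down.

#include "common.cuh"

// Device side: report which path was compiled in for the current architecture.
static __global__ void which_path(int * out) {
#ifdef CP_ASYNC_AVAILABLE
    *out = 1; // asynchronous, double-buffered loads available (Ampere or newer)
#else
    *out = 0; // synchronous fallback
#endif // CP_ASYNC_AVAILABLE
}

// Host side: the same decision taken from the device's compute capability,
// e.g. shared memory for two K/V buffers instead of one.
static size_t demo_nbytes_shared(const int cc, const size_t tile_bytes) {
    return cp_async_available(cc) ? 2*tile_bytes : tile_bytes;
}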
ggml/src/ggml-cuda/cp-async.cuh ADDED
@@ -0,0 +1,46 @@
1
+ // Simplified API for asynchronous data loading.
2
+
3
+ #include "common.cuh"
4
+
5
+ // Copies data from global to shared memory, cg == cache global.
6
 + // Both the src and dst pointers must be aligned to 16 bytes.
7
 + // Shared memory uses 32 bit addressing; the pointer is passed as an unsigned int.
8
 + // Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
9
 + // Only the 16 byte copy is exposed because 4 and 8 byte copies did not yield performance improvements.
10
+ template <int preload>
11
+ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
12
+ static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
13
+ #ifdef CP_ASYNC_AVAILABLE
14
+ #if CUDART_VERSION >= 11040
15
+ if (preload == 256) {
16
+ asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
17
+ : : "r"(dst), "l"(src));
18
+ } else if (preload == 128) {
19
+ asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
20
+ : : "r"(dst), "l"(src));
21
+ } else if (preload == 64) {
22
+ asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
23
+ : : "r"(dst), "l"(src));
24
+ } else
25
+ #endif // CUDART_VERSION >= 11040
26
+ {
27
+ asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
28
+ : : "r"(dst), "l"(src));
29
+ }
30
+ #else
31
+ GGML_UNUSED(dst);
32
+ GGML_UNUSED(src);
33
+ NO_DEVICE_CODE;
34
+ #endif // CP_ASYNC_AVAILABLE
35
+ }
36
+
37
+ // Makes each thread wait until its asynchronous data copies are done.
38
+ // This does NOT provide any additional synchronization.
39
+ // In particular, when copying data with multiple warps a call to __syncthreads will be needed.
40
+ static __device__ __forceinline__ void cp_async_wait_all() {
41
+ #ifdef CP_ASYNC_AVAILABLE
42
+ asm volatile("cp.async.wait_all;");
43
+ #else
44
+ NO_DEVICE_CODE;
45
+ #endif // CP_ASYNC_AVAILABLE
46
+ }
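A small usage sketch for the two helpers above (not part of the commit): one warp copies a rows x cols tile of half2 values into shared memory with a padded row stride, mirroring how flash_attn_ext_f16_load_tile below uses the API. The function name and its assumptions (16-byte aligned pointers, row strides whose byte size is a multiple of 16, cols a multiple of 4, a single copying warp) are hypothetical.

static __device__ void tile_copy_async(
        const half2 * __restrict__ src, const int stride_src, // row stride of the global tensor, in half2
        half2       * __restrict__ dst, const int stride_dst, // padded row stride in shared memory, in half2
        const int rows, const int cols) {
    constexpr int h2_per_cp = 16/sizeof(half2); // 4 half2 == one 16 byte transaction

    // Shared memory uses 32 bit addressing, so convert the generic pointer once:
    const unsigned int dst_32 = __cvta_generic_to_shared(dst);

    for (int i = 0; i < rows; ++i) {
        for (int k = threadIdx.x*h2_per_cp; k < cols; k += WARP_SIZE*h2_per_cp) {
            cp_async_cg_16<0>(dst_32 + (i*stride_dst + k)*sizeof(half2), src + i*stride_src + k);
        }
    }

    cp_async_wait_all(); // waits only for this thread's own copies ...
    __syncwarp();        // ... the other lanes' copies become visible here; use __syncthreads() when several warps copy
}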
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -716,7 +716,9 @@ void launch_fattn(
716
 
717
  ggml_cuda_pool & pool = ctx.pool();
718
  cudaStream_t main_stream = ctx.stream();
719
- const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
 
720
 
721
  ggml_cuda_pool_alloc<half> K_f16(pool);
722
  ggml_cuda_pool_alloc<half> V_f16(pool);
@@ -768,13 +770,14 @@ void launch_fattn(
768
  dim3 blocks_num;
769
  if (parallel_blocks == 0) {
770
  // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
771
- const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
772
- const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
773
- const bool short_context = K->ne[1] < 4096;
774
 
775
  const int nblocks_stream_k = 2*nsm;
776
 
777
- blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
 
 
778
  blocks_num.y = 1;
779
  blocks_num.z = 1;
780
 
@@ -827,7 +830,7 @@ void launch_fattn(
827
  CUDA_CHECK(cudaGetLastError());
828
 
829
  if constexpr (parallel_blocks == 0) {
830
- if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
831
  const dim3 block_dim_combine(D, 1, 1);
832
  const dim3 blocks_num_combine = blocks_num;
833
 
 
716
 
717
  ggml_cuda_pool & pool = ctx.pool();
718
  cudaStream_t main_stream = ctx.stream();
719
+ const int id = ggml_cuda_get_device();
720
+ const int cc = ggml_cuda_info().devices[id].cc;
721
+ const int nsm = ggml_cuda_info().devices[id].nsm;
722
 
723
  ggml_cuda_pool_alloc<half> K_f16(pool);
724
  ggml_cuda_pool_alloc<half> V_f16(pool);
 
770
  dim3 blocks_num;
771
  if (parallel_blocks == 0) {
772
  // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
773
+ const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
774
+ const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
 
775
 
776
  const int nblocks_stream_k = 2*nsm;
777
 
778
+ const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
779
+
780
+ blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
781
  blocks_num.y = 1;
782
  blocks_num.z = 1;
783
 
 
830
  CUDA_CHECK(cudaGetLastError());
831
 
832
  if constexpr (parallel_blocks == 0) {
833
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
834
  const dim3 block_dim_combine(D, 1, 1);
835
  const dim3 blocks_num_combine = blocks_num;
836
 
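For reference, the new launch heuristic re-derived as standalone host code with made-up numbers (nsm = 46, ntiles_total = 100); only the two formulas are taken from the change above.

#include <cstdio>

// Prefer the stream-K decomposition when giving every tile its own CUDA block
// would leave the last wave of blocks mostly idle.
static bool use_stream_k(const int ntiles_total, const int nsm, const bool is_ada_or_newer) {
    const int tiles_nwaves             = (ntiles_total + 2*nsm - 1) / (2*nsm);
    const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
    return tiles_efficiency_percent < 75 || is_ada_or_newer;
}

int main() {
    // Example: 46 SMs -> 92 stream-K blocks; 100 whole tiles would need 2 waves of 92 block slots,
    // of which only 100/184 = 54% do useful work, so the stream-K path is chosen.
    std::printf("use_stream_k = %d\n", use_stream_k(100, 46, /*is_ada_or_newer=*/false)); // prints 1
    return 0;
}

On Ada Lovelace or newer the stream-K path is taken unconditionally, and the fixup kernel afterwards only runs when ntiles_total % blocks_num.x != 0, i.e. when SMs actually work on fractional tiles.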
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -1,242 +1,195 @@
1
  #include "common.cuh"
 
2
  #include "mma.cuh"
3
  #include "fattn-common.cuh"
4
 
5
- template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
6
- static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
7
- const float2 * const __restrict__ Q_f2,
8
- const half2 * const __restrict__ K_h2,
9
- const half2 * const __restrict__ V_h2,
10
- const half * const __restrict__ maskh,
11
- float2 * const __restrict__ dstk,
12
- float2 * const __restrict__ dstk_fixup,
13
- const float scale,
14
- const float slope,
15
- const float logit_softcap,
16
- const int ne00,
17
- const int ne01,
18
- const int ne02,
19
- const int ne03,
20
- const int ne10,
21
- const int ne11,
22
- const int ne12,
23
- const int ne13,
24
- const int ne31,
25
- const int nb31,
26
- const int nb01,
27
- const int nb02,
28
- const int nb03,
29
- const int nb11,
30
- const int nb12,
31
- const int nb13,
32
- const int nb21,
33
- const int nb22,
34
- const int nb23,
35
- const int ne0,
36
- const int ne1,
37
- const int ne2,
38
- const int ne3,
39
- const int jt,
40
- const int kb0_start,
41
- const int kb0_stop) {
42
- #ifdef NEW_MMA_AVAILABLE
43
- //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
44
 
45
- typedef mma_A_I16K8<half2> mma_A;
46
- typedef mma_B_J8K8<half2> mma_B;
47
- typedef mma_C_I16J8<float> mma_C_KQ;
48
- typedef mma_C_I16J8<half2> mma_C_VKQ;
49
-
50
- static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
51
- constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
52
-
53
- static_assert(D % nwarps == 0, "bad D");
54
- static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
55
 
 
 
 
56
  constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
57
- extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
58
 
59
- const int stride_Q = nb01 / sizeof(float2);
60
- const int stride_KV = nb11 / sizeof(half2);
61
- const int stride_mask = nb31 / sizeof(half);
 
62
 
63
- mma_B Q_B[D/(2*mma_B::K)];
64
- mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
65
 
66
- float2 KQ_rowsum = {0.0f, 0.0f};
67
- float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
68
- float2 KQ_max_scale = {0.0f, 0.0f};
 
 
 
 
 
69
 
70
- // Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
71
- // The loading is done with decreasing granularity for D for better memory bandwidth.
72
- const half2 scale_h2 = make_half2(scale, scale);
 
 
 
 
 
 
 
73
  #pragma unroll
74
  for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
75
- const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
76
- const int k0_stop = D/2 - (D/2) % (1*stride_k);
77
- const int stride_j = WARP_SIZE / stride_k;
78
 
79
- if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
80
- break;
81
  }
82
 
83
  #pragma unroll
84
- for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
85
- const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
86
 
87
- if (jt*ncols + j < ne01) {
88
- #pragma unroll
89
- for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
90
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
91
-
92
- const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
93
- tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
94
- }
95
- } else {
96
  #pragma unroll
97
- for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
98
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
99
 
100
- tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
101
- }
102
  }
103
  }
104
  }
 
105
 
106
- __syncthreads();
107
-
108
- {
109
- const int j0 = (threadIdx.y / np) * mma_B::J;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- #pragma unroll
112
- for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
113
- Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
114
- }
115
- }
116
 
 
 
117
  __syncthreads();
 
 
 
 
 
118
 
119
- // Iterate over ne11 == previous tokens:
120
- for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
121
- const int k_VKQ_0 = kb0*KQ_stride;
122
- mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
123
-
124
- // Load K data into tile with decreasing granularity for D for better memory bandwidth:
125
- static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
126
- #pragma unroll
127
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
128
- const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
129
- const int k0_stop = D/2 - (D/2) % (1*stride_k);
130
- const int stride_i = WARP_SIZE / stride_k;
131
-
132
- #pragma unroll
133
- for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
134
- const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
135
-
136
- #pragma unroll
137
- for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
138
- const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
139
-
140
- tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
141
- }
142
- }
143
- }
144
-
145
- __syncthreads();
146
-
147
- // Calculate tile of KQ:
148
  #pragma unroll
149
- for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
150
- const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
151
  #pragma unroll
152
- for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
153
- mma_A K_A;
154
- K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
155
- KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
156
- }
157
  }
 
158
 
159
- __syncthreads();
 
 
160
 
161
- if (use_logit_softcap) {
162
- static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
163
  #pragma unroll
164
- for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
165
  #pragma unroll
166
- for (int l = 0; l < mma_C_KQ::ne; ++l) {
167
- KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
168
- }
169
  }
170
  }
 
171
 
172
- if (maskh) {
173
- static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size");
174
- static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
175
  #pragma unroll
176
- for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
177
- const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
178
  #pragma unroll
179
- for (int l = 0; l < mma_C_KQ::ne; ++l) {
180
- const int i = i0 + mma_C_KQ::get_i(l);
181
- const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
182
 
183
- KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
184
- }
185
  }
186
  }
 
187
 
188
- // Calculate softmax for each KQ column using the current max. value.
189
- // The divisor is stored in KQ_rowsum and will be applied at the end.
190
- float2 KQ_max_new = KQ_max;
191
- static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
192
  #pragma unroll
193
- for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
194
  #pragma unroll
195
- for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
196
- KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
197
- KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
198
- }
199
  }
 
200
 
201
- // Values per KQ column are spread across 8 threads, does not need full warp reduce:
202
  #pragma unroll
203
- for (int offset = 16; offset > 2; offset >>= 1) {
204
- KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
205
- KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
206
- }
207
-
208
- {
209
- const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
210
- KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
211
- if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
212
- KQ_max_scale.x = 0.0f;
213
- }
214
- if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
215
- KQ_max_scale.y = 0.0f;
216
- }
217
- KQ_max = KQ_max_new;
218
- }
219
 
220
- float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
221
- static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
222
  #pragma unroll
223
- for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
224
  #pragma unroll
225
- for (int l = 0; l < mma_C_KQ::ne; ++l) {
226
- const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
227
- const float diff = KQ_C[k].x[l] - KQ_max_l;
228
- KQ_C[k].x[l] = expf(diff);
229
- if (diff <= SOFTMAX_FTZ_THRESHOLD) {
230
- KQ_C[k].x[l] = 0.0f;
231
- }
232
 
233
- if (l % 2 == 0) {
234
- KQ_rowsum_add.x += KQ_C[k].x[l];
235
- } else {
236
- KQ_rowsum_add.y += KQ_C[k].x[l];
237
- }
238
  }
239
  }
 
 
 
 
 
 
240
 
241
  // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
242
  KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
@@ -244,60 +197,179 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
244
 
245
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
246
  #pragma unroll
247
- for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
248
  #pragma unroll
249
- for (int l = 0; l < mma_C_VKQ::ne; ++l) {
250
  VKQ_C[i].x[l] *= KQ_max_scale_h2;
251
  }
252
  }
 
 
 
 
 
 
 
 
 
253
 
254
- // Convert KQ C tiles into B tiles for VKQ calculation:
255
- mma_B B[KQ_stride/(np*2*mma_B::K)];
256
- static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
 
 
 
 
 
 
 
 
 
 
 
 
257
  #pragma unroll
258
- for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
259
- B[k] = KQ_C[k].to_mma_B();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  }
261
 
262
- // Load V data into tile with decreasing granularity for D for better memory bandwidth:
263
- static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
264
  #pragma unroll
265
- for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
266
- const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
267
- const int i0_stop = D/2 - (D/2) % (1*stride_i);
268
- const int stride_k = WARP_SIZE / stride_i;
269
 
 
270
  #pragma unroll
271
- for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
272
- const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
273
 
 
 
 
 
274
  #pragma unroll
275
- for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
276
- const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
277
 
278
- tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
279
  }
280
  }
281
  }
 
282
 
283
- __syncthreads();
284
 
285
- // Calculate VKQ tile:
286
- #pragma unroll
287
- for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
288
- static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
289
- #pragma unroll
290
- for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
291
- const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
292
 
293
- mma_A A;
294
- A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
295
- VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
296
- }
297
  }
 
 
 
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  __syncthreads();
300
  }
 
301
 
302
  // Finally, sum up partial KQ rowsums.
303
  // The partial sums are spread across 8 threads each, does not need full reduce.
@@ -310,26 +382,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
310
  // Write VKQ accumulators to shared memory in column-major format.
311
  // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
312
  // Also for np > 1 the combination is done via these values in shared memory.
313
- const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
314
  #pragma unroll
315
- for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
316
- const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
317
 
318
  #pragma unroll
319
- for (int l = 0; l < mma_B::ne; ++l) {
320
- const int k = k0 + mma_B::get_k(l);
321
 
322
- tile_KV[j_cwd*D2_padded + k] = B.x[l];
323
  }
324
  }
325
 
326
- const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
327
- const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
328
  const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
329
 
330
- if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
331
  // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
332
- ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
333
  }
334
 
335
  __syncthreads();
@@ -337,11 +409,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
337
  static_assert(np == 1 || np == 2 || np == 4, "bad np");
338
  if (np == 1) {
339
  // No combination is needed, the meta data can be directly written from registers to VRAM.
340
- if (needs_fixup && threadIdx.x < mma_B::J) {
341
  float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
342
  dstk_fixup_meta[j_cwm] = KQ_cmr;
343
  }
344
- if (is_fixup && threadIdx.x < mma_B::J) {
345
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
346
  dstk_fixup_meta[j_cwm] = KQ_cmr;
347
  }
@@ -350,42 +422,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
350
  // Warps with threadIdx.y % np != 0 must NOT return early.
351
  // All threads must return simultaneously to avoid race conditions with work on the next tile.
352
 
353
- float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
354
 
355
  float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
356
- if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
357
  KQ_cm = meta_j[0];
358
  }
359
 
360
  float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
361
  #pragma unroll
362
- for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
363
  KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
364
  }
365
 
366
  const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
367
  float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
368
- if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
369
  KQ_crs = KQ_cms*meta_j[1];
370
  }
371
  #pragma unroll
372
- for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
373
  KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
374
  }
375
 
376
  // Write back combined meta data:
377
- if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
378
- meta_j[0] = KQ_cmn; // Combined max. KQ values.
379
- meta_j[1] = KQ_crs; // Combined KQ rowsums.
380
- meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
381
  }
382
- if (needs_fixup && threadIdx.x < mma_B::J) {
383
  float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
384
- dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
385
  }
386
- if (is_fixup && threadIdx.x < mma_B::J) {
387
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
388
- dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
389
  }
390
  }
391
 
@@ -404,6 +474,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
404
  const int k0_stop = D/2 - (D/2) % (1*stride_k);
405
  const int stride_j = WARP_SIZE / stride_k;
406
 
 
 
 
 
407
  if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
408
  break;
409
  }
@@ -411,12 +485,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
411
  #pragma unroll
412
  for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
413
  const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
414
- const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
415
 
416
  if (!is_fixup && jt*ncols + j_dst >= ne01) {
417
  continue;
418
  }
419
- const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
420
  #pragma unroll
421
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
422
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -424,8 +498,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
424
  float2 dstk_val = make_float2(0.0f, 0.0f);
425
  #pragma unroll
426
  for (int ip = 0; ip < np; ++ip) {
427
- const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
428
- const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
429
  dstk_val.x += dstk_val_add.x*KQ_crs;
430
  dstk_val.y += dstk_val_add.y*KQ_crs;
431
  }
@@ -450,7 +524,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
450
  __syncthreads();
451
  }
452
  #else
453
- NO_DEVICE_CODE;
454
  #endif // NEW_MMA_AVAILABLE
455
  }
456
 
@@ -494,6 +568,11 @@ static __global__ void flash_attn_ext_f16(
494
  const int ne1,
495
  const int ne2,
496
  const int ne3) {
 
 
 
 
 
497
  // Skip unused kernel variants for faster compilation:
498
  if (use_logit_softcap && !(D == 128 || D == 256)) {
499
  NO_DEVICE_CODE;
@@ -504,6 +583,10 @@ static __global__ void flash_attn_ext_f16(
504
 
505
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
506
 
 
 
 
 
507
  const int iter_k = ne11 / KQ_stride;
508
  const int iter_j = (ne01 + (ncols - 1)) / ncols;
509
 
@@ -535,14 +618,12 @@ static __global__ void flash_attn_ext_f16(
535
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
536
  flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
537
  (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
538
- ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
539
- jt, kb0_start, kb0_stop);
540
  } else {
541
  constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
542
  flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
543
  (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
544
- ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
545
- jt, kb0_start, kb0_stop);
546
  }
547
 
548
  kbc += iter_k;
@@ -571,24 +652,27 @@ static __global__ void flash_attn_ext_f16(
571
  constexpr bool needs_fixup = false;
572
  flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
573
  (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
574
- ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
575
- jt, kb0_start, kb0_stop);
576
  }
577
 
578
  template <int D, int cols_per_block>
579
  void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
580
- typedef mma_A_I16K8<half2> mma_A;
581
- typedef mma_B_J8K8<half2> mma_B;
582
 
583
- static_assert(D % mma_B::K == 0, "bad D");
584
- static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
585
 
586
  const ggml_tensor * KQV = dst;
 
 
 
 
 
587
 
588
- constexpr int KQ_stride = D <= 128 ? 64 : 32;
589
- constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
590
- cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
591
- constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
592
 
593
  float logit_softcap;
594
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
1
  #include "common.cuh"
2
+ #include "cp-async.cuh"
3
  #include "mma.cuh"
4
  #include "fattn-common.cuh"
5
 
6
 + using namespace ggml_cuda_mma;
 
 
 
7
 
8
+ typedef tile<16, 8, half2> tile_A;
9
+ typedef tile< 8, 8, half2> tile_B;
10
+ typedef tile<16, 8, float> tile_C_KQ;
11
+ typedef tile<16, 4, half2> tile_C_VKQ;
 
 
 
 
 
 
12
 
13
+ template<int D, int nwarps, int KQ_stride>
14
+ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
15
+ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
16
  constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
 
17
 
18
+ // If cp.async is available, load up to the highest power of 2 in D asynchronously:
19
+ #ifdef CP_ASYNC_AVAILABLE
20
+ static_assert(D >= 64 && D < 512, "bad D");
21
+ constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);
22
 
23
+ const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
 
24
 
25
+ constexpr int preload = 64;
26
+ constexpr int h2_per_chunk = 16/sizeof(half2);
27
+ constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
28
+ constexpr int stride_i = WARP_SIZE / chunks_per_row;
29
+ #pragma unroll
30
+ for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
31
+ const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
32
+ const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
33
 
34
+ cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
35
+ }
36
+ #else
37
+ constexpr int k0_sync_start = 0;
38
+ #endif // CP_ASYNC_AVAILABLE
39
+ static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
40
+
41
+ // If D is not a power of 2, the rest is loaded synchronously.
42
+ // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
43
+ static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
44
  #pragma unroll
45
  for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
46
+ const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
47
+ const int k0_stop = D/2 - (D/2) % (1*stride_k);
48
+ const int stride_i = WARP_SIZE / stride_k;
49
 
50
+ if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
51
+ continue;
52
  }
53
 
54
  #pragma unroll
55
+ for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
56
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
57
 
 
 
 
 
 
 
 
 
 
58
  #pragma unroll
59
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
60
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
61
 
62
+ tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
 
63
  }
64
  }
65
  }
66
+ }
67
 
68
+ template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
69
+ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
70
+ const float2 * const __restrict__ Q_f2,
71
+ const half2 * const __restrict__ K_h2,
72
+ const half2 * const __restrict__ V_h2,
73
+ const half * const __restrict__ maskh,
74
+ float2 * const __restrict__ dstk,
75
+ float2 * const __restrict__ dstk_fixup,
76
+ const float scale,
77
+ const float slope,
78
+ const float logit_softcap,
79
+ const int ne01,
80
+ const int ne02,
81
+ const int stride_Q,
82
+ const int stride_KV,
83
+ const int stride_mask,
84
+ const int jt,
85
+ half2 * const __restrict__ tile_K,
86
+ half2 * const __restrict__ tile_V,
87
+ const tile_B * const __restrict__ Q_B,
88
+ tile_C_VKQ * const __restrict__ VKQ_C,
89
+ float2 & KQ_max,
90
+ float2 & KQ_rowsum,
91
+ const int kb0) {
92
+ #ifdef NEW_MMA_AVAILABLE
93
+ constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
94
+ constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
95
 
96
+ const int k_VKQ_0 = kb0*KQ_stride;
97
+ tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)];
 
 
 
98
 
99
+ #ifdef CP_ASYNC_AVAILABLE
100
+ cp_async_wait_all();
101
  __syncthreads();
102
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
103
+ #else
104
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
105
+ __syncthreads();
106
+ #endif // CP_ASYNC_AVAILABLE
107
 
108
 + // Calculate tile of KQ:
 
 
 
109
  #pragma unroll
110
+ for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) {
111
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
112
  #pragma unroll
113
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
114
+ tile_A K_A;
115
+ load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
116
+ mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]);
 
117
  }
118
+ }
119
 
120
+ #ifndef CP_ASYNC_AVAILABLE
121
+ __syncthreads(); // Only needed if tile_K == tile_V.
122
+ #endif // CP_ASYNC_AVAILABLE
123
 
124
+ if (use_logit_softcap) {
125
+ static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
126
  #pragma unroll
127
+ for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) {
128
  #pragma unroll
129
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
130
+ KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
 
131
  }
132
  }
133
+ }
134
 
135
+ if (maskh) {
136
+ static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size");
137
+ static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size");
138
  #pragma unroll
139
+ for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) {
140
+ const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
141
  #pragma unroll
142
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
143
+ const int i = i0 + tile_C_KQ::get_i(l);
144
+ const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l);
145
 
146
+ KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
 
147
  }
148
  }
149
+ }
150
 
151
+ // Calculate softmax for each KQ column using the current max. value.
152
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
153
+ float2 KQ_max_new = KQ_max;
154
+ static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
155
  #pragma unroll
156
+ for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
157
  #pragma unroll
158
+ for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) {
159
+ KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
160
+ KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
 
161
  }
162
+ }
163
 
164
+ // Values per KQ column are spread across 8 threads, does not need full warp reduce:
165
  #pragma unroll
166
+ for (int offset = 16; offset > 2; offset >>= 1) {
167
+ KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
168
+ KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
169
+ }
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
+ float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
172
+ static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
173
  #pragma unroll
174
+ for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
175
  #pragma unroll
176
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
177
+ const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y;
178
+ const float diff = KQ_C[k].x[l] - KQ_max_l;
179
+ KQ_C[k].x[l] = expf(diff);
 
 
 
180
 
181
+ if (l % 2 == 0) {
182
+ KQ_rowsum_add.x += KQ_C[k].x[l];
183
+ } else {
184
+ KQ_rowsum_add.y += KQ_C[k].x[l];
 
185
  }
186
  }
187
+ }
188
+
189
+ {
190
+ const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
191
+ const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
192
+ KQ_max = KQ_max_new;
193
 
194
  // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
195
  KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
 
197
 
198
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
199
  #pragma unroll
200
+ for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
201
  #pragma unroll
202
+ for (int l = 0; l < tile_C_VKQ::ne; ++l) {
203
  VKQ_C[i].x[l] *= KQ_max_scale_h2;
204
  }
205
  }
206
+ }
207
+
208
+ // Convert KQ C tiles into B tiles for VKQ calculation:
209
+ tile_B B[KQ_stride/(np*2*tile_B::J)];
210
+ static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size");
211
+ #pragma unroll
212
+ for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) {
213
+ B[k] = get_transposed(get_half2(KQ_C[k]));
214
+ }
215
 
216
+ #ifdef CP_ASYNC_AVAILABLE
217
+ cp_async_wait_all();
218
+ __syncthreads();
219
+ if (!last_iter) {
220
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV);
221
+ }
222
+ #else
223
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
224
+ __syncthreads();
225
+ #endif // CP_ASYNC_AVAILABLE
226
+
227
+ // Calculate VKQ tile:
228
+ #pragma unroll
229
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
230
+ static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size");
231
  #pragma unroll
232
+ for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) {
233
+ const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
234
+
235
+ tile_A A;
236
+ load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
237
+ mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
238
+ }
239
+ }
240
+
241
+ #ifndef CP_ASYNC_AVAILABLE
242
+ __syncthreads(); // Only needed if tile_K == tile_V.
243
+ #endif // CP_ASYNC_AVAILABLE
244
+
245
+ #else
246
+ NO_DEVICE_CODE;
247
+ #endif // NEW_MMA_AVAILABLE
248
+ }
249
+
250
+ template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
251
+ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
252
+ const float2 * const __restrict__ Q_f2,
253
+ const half2 * const __restrict__ K_h2,
254
+ const half2 * const __restrict__ V_h2,
255
+ const half * const __restrict__ maskh,
256
+ float2 * const __restrict__ dstk,
257
+ float2 * const __restrict__ dstk_fixup,
258
+ const float scale,
259
+ const float slope,
260
+ const float logit_softcap,
261
+ const int ne01,
262
+ const int ne02,
263
+ const int stride_Q,
264
+ const int stride_KV,
265
+ const int stride_mask,
266
+ const int jt,
267
+ const int kb0_start,
268
+ const int kb0_stop) {
269
+ #ifdef NEW_MMA_AVAILABLE
270
 + // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
271
+
272
+ static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
273
+ constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
274
+
275
+ static_assert(D % nwarps == 0, "bad D");
276
+ static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
277
+
278
+ constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
279
+
280
+ // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements:
281
+ extern __shared__ half2 tile_K[];
282
+ #ifdef CP_ASYNC_AVAILABLE
283
+ half2 * tile_V = tile_K + KQ_stride*D2_padded;
284
+ #else
285
+ half2 * tile_V = tile_K;
286
+ #endif // CP_ASYNC_AVAILABLE
287
+
288
+ tile_B Q_B[D/(2*tile_B::J)];
289
+ tile_C_VKQ VKQ_C[D/tile_C_VKQ::I];
290
+
291
+ float2 KQ_rowsum = {0.0f, 0.0f};
292
+ float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
293
+
294
+ // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
295
+ // The loading is done with decreasing granularity for D for better memory bandwidth.
296
+ const half2 scale_h2 = make_half2(scale, scale);
297
+ #pragma unroll
298
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
299
+ const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
300
+ const int k0_stop = D/2 - (D/2) % (1*stride_k);
301
+ const int stride_j = WARP_SIZE / stride_k;
302
+
303
+ if (k0_start == k0_stop) {
304
+ continue;
305
+ }
306
+
307
+ if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
308
+ break;
309
  }
310
 
 
 
311
  #pragma unroll
312
+ for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
313
+ const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
 
314
 
315
+ if (jt*ncols + j < ne01) {
316
  #pragma unroll
317
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
318
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
319
 
320
+ const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
321
+ tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
322
+ }
323
+ } else {
324
  #pragma unroll
325
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
326
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
327
 
328
+ tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f);
329
  }
330
  }
331
  }
332
+ }
333
 
334
+ __syncthreads();
335
 
336
+ {
337
+ const int j0 = (threadIdx.y / np) * tile_B::I;
 
 
 
 
 
338
 
339
+ #pragma unroll
340
+ for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
341
+ load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
 
342
  }
343
+ }
344
+
345
+ __syncthreads();
346
 
347
+ // Preload K data for first iteration when using cp_async:
348
+ #ifdef CP_ASYNC_AVAILABLE
349
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV);
350
+ #endif // CP_ASYNC_AVAILABLE
351
+
352
+ // Iterate over ne11 == previous tokens:
353
+ for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
354
+ constexpr bool last_iter = false;
355
+ flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
356
+ (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
357
+ ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
358
+ }
359
+ { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
360
+ constexpr bool last_iter = true;
361
+ flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
362
+ (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
363
+ ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
364
+ }
365
+
366
+ // With cp_async there is no __syncthreads at the end of the iter,
367
+ // there can be a race condition on shared memory access for combining/writing back results.
368
+ #ifdef CP_ASYNC_AVAILABLE
369
+ if (nwarps*tile_B::I > KQ_stride) {
370
  __syncthreads();
371
  }
372
+ #endif // CP_ASYNC_AVAILABLE
373
 
374
  // Finally, sum up partial KQ rowsums.
375
  // The partial sums are spread across 8 threads each, does not need full reduce.
 
382
  // Write VKQ accumulators to shared memory in column-major format.
383
  // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
384
  // Also for np > 1 the combination is done via these values in shared memory.
385
+ const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data
386
  #pragma unroll
387
+ for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
388
+ const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
389
 
390
  #pragma unroll
391
+ for (int l = 0; l < tile_B::ne; ++l) {
392
+ const int k = k0 + tile_B::get_j(l);
393
 
394
+ tile_K[j_cwd*D2_padded + k] = B.x[l];
395
  }
396
  }
397
 
398
+ const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset
399
+ const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
400
  const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
401
 
402
+ if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
403
  // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
404
+ ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
405
  }
406
 
407
  __syncthreads();
 
409
  static_assert(np == 1 || np == 2 || np == 4, "bad np");
410
  if (np == 1) {
411
  // No combination is needed, the meta data can be directly written from registers to VRAM.
412
+ if (needs_fixup && threadIdx.x < tile_B::I) {
413
  float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
414
  dstk_fixup_meta[j_cwm] = KQ_cmr;
415
  }
416
+ if (is_fixup && threadIdx.x < tile_B::I) {
417
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
418
  dstk_fixup_meta[j_cwm] = KQ_cmr;
419
  }
 
422
  // Warps with threadIdx.y % np != 0 must NOT return early.
423
  // All threads must return simultaneously to avoid race conditions with work on the next tile.
424
 
425
+ float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2;
426
 
427
  float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
428
+ if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
429
  KQ_cm = meta_j[0];
430
  }
431
 
432
  float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
433
  #pragma unroll
434
+ for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
435
  KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
436
  }
437
 
438
  const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
439
  float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
440
+ if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
441
  KQ_crs = KQ_cms*meta_j[1];
442
  }
443
  #pragma unroll
444
+ for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
445
  KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
446
  }
447
 
448
  // Write back combined meta data:
449
+ if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
450
+ *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum.
 
 
451
  }
452
+ if (needs_fixup && threadIdx.x < tile_B::I) {
453
  float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
454
+ dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
455
  }
456
+ if (is_fixup && threadIdx.x < tile_B::I) {
457
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
458
+ dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
459
  }
460
  }
461
 
 
474
  const int k0_stop = D/2 - (D/2) % (1*stride_k);
475
  const int stride_j = WARP_SIZE / stride_k;
476
 
477
+ if (k0_start == k0_stop) {
478
+ continue;
479
+ }
480
+
481
  if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
482
  break;
483
  }
 
485
  #pragma unroll
486
  for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
487
  const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
488
+ const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I;
489
 
490
  if (!is_fixup && jt*ncols + j_dst >= ne01) {
491
  continue;
492
  }
493
+ const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2;
494
  #pragma unroll
495
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
496
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
498
  float2 dstk_val = make_float2(0.0f, 0.0f);
499
  #pragma unroll
500
  for (int ip = 0; ip < np; ++ip) {
501
+ const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0];
502
+ const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]);
503
  dstk_val.x += dstk_val_add.x*KQ_crs;
504
  dstk_val.y += dstk_val_add.y*KQ_crs;
505
  }
 
524
  __syncthreads();
525
  }
526
  #else
527
+ NO_DEVICE_CODE;
528
  #endif // NEW_MMA_AVAILABLE
529
  }
530
 
 
568
  const int ne1,
569
  const int ne2,
570
  const int ne3) {
571
+ #ifndef NEW_MMA_AVAILABLE
572
+ NO_DEVICE_CODE;
573
+ return;
574
+ #endif // NEW_MMA_AVAILABLE
575
+
576
  // Skip unused kernel variants for faster compilation:
577
  if (use_logit_softcap && !(D == 128 || D == 256)) {
578
  NO_DEVICE_CODE;
 
583
 
584
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
585
 
586
+ const int stride_Q = nb01 / sizeof(float2);
587
+ const int stride_KV = nb11 / sizeof(half2);
588
+ const int stride_mask = nb31 / sizeof(half);
589
+
590
  const int iter_k = ne11 / KQ_stride;
591
  const int iter_j = (ne01 + (ncols - 1)) / ncols;
592
 
 
618
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
619
  flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
620
  (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
621
+ ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
 
622
  } else {
623
  constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
624
  flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
625
  (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
626
+ ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
 
627
  }
628
 
629
  kbc += iter_k;
 
652
  constexpr bool needs_fixup = false;
653
  flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
654
  (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
655
+ ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
 
656
  }
657
 
658
  template <int D, int cols_per_block>
659
  void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
660
+ typedef tile<16, 8, half2> tile_A;
661
+ typedef tile< 8, 8, half2> tile_B;
662
 
663
+ static_assert(D % tile_B::J == 0, "bad D");
664
+ static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block");
665
 
666
  const ggml_tensor * KQV = dst;
667
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
668
+
669
+ constexpr int KQ_stride = D <= 128 ? 64 : 32;
670
+ constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
671
+ cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8);
672
 
673
+ const int nrows_KQ = cp_async_available(cc) ? 2*KQ_stride : KQ_stride;
674
+ const int nrows_combine = nwarps*tile_B::J;
675
+ const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half);
 
676
 
677
  float logit_softcap;
678
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
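The cp.async path keeps K and V in separate shared-memory buffers, so the number of K/V rows held in shared memory doubles. A hedged back-of-the-envelope check of the nbytes_shared formula above: D = 128 and cols_per_block = 32 are assumed example parameters, while tile_A::I = 16 and tile_B::J = 8 come from the typedefs at the top of the file.

#include <algorithm>
#include <cstdio>

int main() {
    // Assumed instantiation: D = 128, cols_per_block = 32; tile shapes from the typedefs above.
    const int D              = 128;
    const int cols_per_block = 32;
    const int tile_A_I = 16, tile_B_J = 8;

    const int KQ_stride = D <= 128 ? 64 : 32;                                        // 64 rows of K/V per iteration
    const int nwarps    = (KQ_stride == 32 && cols_per_block <= 16) ?
        cols_per_block/tile_B_J * KQ_stride/tile_A_I : (cols_per_block <= 8 ? 4 : 8); // 8

    const bool async_options[2] = {false, true};
    for (const bool async_loading : async_options) {
        const int nrows_KQ      = async_loading ? 2*KQ_stride : KQ_stride; // separate K and V buffers vs. one shared buffer
        const int nrows_combine = nwarps*tile_B_J;                         // rows needed for combining the partial results
        const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * 2; // sizeof(half) == 2

        std::printf("async=%d -> %zu bytes of dynamic shared memory\n", (int) async_loading, nbytes_shared);
    }
    return 0;
}

For this configuration the kernel therefore requests 34816 instead of 17408 bytes of dynamic shared memory when cp_async_available(cc) is true.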
ggml/src/ggml-cuda/mma.cuh CHANGED
@@ -4,11 +4,12 @@
4
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
5
  //
6
  // Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
7
- // A is a row-major matrix with shape I x K.
8
- // B is a column-major matrix with shape K x J.
9
- // C is a column-major matrix with shape I x J.
10
- // Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
11
- // The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
 
12
  // All matrix tiles have ne physical 32 bit elements per warp.
13
  //
14
  // As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
@@ -23,7 +24,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
23
 
24
  #ifdef NEW_MMA_AVAILABLE
25
  asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
26
- : "+r"(ret) : "r"(x));
27
  #else
28
  NO_DEVICE_CODE;
29
  #endif // defined(NEW_MMA_AVAILABLE)
@@ -52,407 +53,267 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
52
 
53
  #endif // CUDART_VERSION >= 11080
54
 
 
 
 
 
 
55
 
56
- template <typename T>
57
- struct mma_A_I16K4 {
58
- static_assert(sizeof(T) == 4, "bad type size");
59
-
60
- static constexpr int I = 16;
61
- static constexpr int K = 4;
62
- static constexpr int ne = 2;
63
-
64
- T x[ne];
 
 
 
 
 
 
 
 
 
65
 
66
- static __device__ __forceinline__ int get_i(const int l) {
67
- const int ret = (l%2) * (I/2) + threadIdx.x / K;
68
- GGML_CUDA_ASSUME(ret >= 0);
69
- GGML_CUDA_ASSUME(ret < I);
70
- return ret;
71
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- static __device__ __forceinline__ int get_k(const int /* l */) {
74
- const int ret = threadIdx.x % K;
75
- GGML_CUDA_ASSUME(ret >= 0);
76
- GGML_CUDA_ASSUME(ret < K);
77
- return ret;
78
- }
 
 
 
 
 
 
79
 
80
- __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
 
 
81
  #pragma unroll
82
- for (int l = 0; l < ne; ++l) {
83
- x[l] = xs0[get_i(l)*stride + get_k(l)];
84
  }
85
- }
86
-
87
- __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
88
- #ifdef NEW_MMA_AVAILABLE
89
- int * xi = (int *) x;
90
- const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
91
- asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
92
- : "+r"(xi[0]), "+r"(xi[1])
93
- : "l"(xs));
94
- #else
95
- load_generic(xs0, stride);
96
- #endif // NEW_MMA_AVAILABLE
97
- }
98
- };
99
-
100
- template <typename T>
101
- struct mma_A_I16K8 {
102
- static_assert(sizeof(T) == 4, "bad type size");
103
-
104
- static constexpr int I = 16;
105
- static constexpr int K = 8;
106
- static constexpr int ne = 4;
107
-
108
- T x[ne];
109
-
110
- static __device__ __forceinline__ int get_i(const int l) {
111
- const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
112
- GGML_CUDA_ASSUME(ret >= 0);
113
- GGML_CUDA_ASSUME(ret < I);
114
  return ret;
115
  }
116
 
117
- static __device__ __forceinline__ int get_k(const int l) {
118
- const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
119
- GGML_CUDA_ASSUME(ret >= 0);
120
- GGML_CUDA_ASSUME(ret < K);
 
121
  return ret;
122
  }
123
 
124
- __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
 
125
  #pragma unroll
126
- for (int l = 0; l < ne; ++l) {
127
- x[l] = xs0[get_i(l)*stride + get_k(l)];
128
  }
129
  }
130
 
131
- __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
 
 
132
  #ifdef NEW_MMA_AVAILABLE
133
- int * xi = (int * ) x;
134
- const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
135
- asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
136
- : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
137
  : "l"(xs));
138
  #else
139
- GGML_UNUSED(xs0);
140
- GGML_UNUSED(stride);
141
- NO_DEVICE_CODE;
142
  #endif // NEW_MMA_AVAILABLE
143
  }
144
 
145
- __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
 
 
146
  #ifdef NEW_MMA_AVAILABLE
147
- int * xi = (int * ) x;
148
- const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
149
- asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
150
- : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
151
  : "l"(xs));
152
  #else
153
- GGML_UNUSED(xs0);
154
- GGML_UNUSED(stride);
155
- NO_DEVICE_CODE;
156
  #endif // NEW_MMA_AVAILABLE
157
  }
158
 
159
- __device__ __forceinline__ void transpose() {
160
- int * xi = (int *) x;
161
- xi[0] = ggml_cuda_movmatrix(xi[0]);
162
-
163
- const int tmp = ggml_cuda_movmatrix(xi[1]);
164
- xi[1] = ggml_cuda_movmatrix(xi[2]);
165
- xi[2] = tmp;
166
-
167
- xi[3] = ggml_cuda_movmatrix(xi[3]);
168
- }
169
- };
170
-
171
- template <typename T>
172
- struct mma_B_J8K4 {
173
- static_assert(sizeof(T) == 4, "bad type size");
174
-
175
- static constexpr int J = 8;
176
- static constexpr int K = 4;
177
- static constexpr int ne = 1;
178
-
179
- T x[ne];
180
-
181
- static __device__ __forceinline__ int get_j(const int /* l */) {
182
- const int ret = threadIdx.x / K;
183
- GGML_CUDA_ASSUME(ret >= 0);
184
- GGML_CUDA_ASSUME(ret < J);
185
- return ret;
186
- }
187
-
188
- static __device__ __forceinline__ int get_k(const int /* l */) {
189
- const int ret = threadIdx.x % K;
190
- GGML_CUDA_ASSUME(ret >= 0);
191
- GGML_CUDA_ASSUME(ret < K);
192
- return ret;
193
- }
194
-
195
- __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
196
- #pragma unroll
197
- for (int l = 0; l < ne; ++l) {
198
- x[l] = xs0[get_j(l)*stride + get_k(l)];
199
- }
200
- }
201
-
202
- __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
203
  #ifdef NEW_MMA_AVAILABLE
204
- int * xi = (int *) x;
205
- const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
206
- asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
207
- : "+r"(xi[0]) : "l"(xs));
 
208
  #else
209
- load_generic(xs0, stride);
210
  #endif // NEW_MMA_AVAILABLE
211
  }
212
- };
213
-
214
- template <typename T>
215
- struct mma_B_J8K8 {
216
- static_assert(sizeof(T) == 4, "bad type size");
217
-
218
- static constexpr int J = 8;
219
- static constexpr int K = 8;
220
- static constexpr int ne = 2;
221
 
222
- T x[ne];
223
-
224
- static __device__ __forceinline__ int get_j(const int /* l */) {
225
- const int ret = threadIdx.x / (K/2);
226
- GGML_CUDA_ASSUME(ret >= 0);
227
- GGML_CUDA_ASSUME(ret < J);
228
- return ret;
229
- }
230
-
231
- static __device__ __forceinline__ int get_k(const int l) {
232
- const int ret = l * (K/2) + threadIdx.x % (K/2);
233
- GGML_CUDA_ASSUME(ret >= 0);
234
- GGML_CUDA_ASSUME(ret < K);
235
- return ret;
236
- }
237
-
238
- __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
239
- #pragma unroll
240
- for (int l = 0; l < ne; ++l) {
241
- x[l] = xs0[get_j(l)*stride + get_k(l)];
242
- }
243
- }
244
-
245
- __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
246
  #ifdef NEW_MMA_AVAILABLE
247
- int * xi = (int *) x;
248
- const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
249
- asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
250
- : "+r"(xi[0]), "+r"(xi[1])
251
  : "l"(xs));
252
  #else
253
- load_generic(xs0, stride);
 
 
 
254
  #endif // NEW_MMA_AVAILABLE
255
  }
256
- };
257
-
258
- template <typename T>
259
- struct mma_C_I16J8 {};
260
-
261
- template <>
262
- struct mma_C_I16J8<int> {
263
- static constexpr int I = 16;
264
- static constexpr int J = 8;
265
- static constexpr int ne = 4;
266
 
267
- int x[ne] = {0};
268
-
269
- static __device__ __forceinline__ int get_i(const int l) {
270
- const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
271
- GGML_CUDA_ASSUME(ret >= 0);
272
- GGML_CUDA_ASSUME(ret < I);
273
- return ret;
274
- }
275
-
276
- static __device__ __forceinline__ int get_j(const int l) {
277
- const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
278
- GGML_CUDA_ASSUME(ret >= 0);
279
- GGML_CUDA_ASSUME(ret < J);
280
- return ret;
281
- }
282
-
283
- __device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
284
  #ifdef NEW_MMA_AVAILABLE
285
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
286
  asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
287
- : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
288
- : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
289
  #else
290
  // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
291
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
292
- : "+r"(x[0]), "+r"(x[1])
293
- : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
294
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
295
- : "+r"(x[2]), "+r"(x[3])
296
- : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
297
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
298
  #else
299
- GGML_UNUSED(mma_A);
300
- GGML_UNUSED(mma_B);
 
301
  NO_DEVICE_CODE;
302
  #endif // NEW_MMA_AVAILABLE
303
  }
304
 
305
- __device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
 
306
  #ifdef NEW_MMA_AVAILABLE
307
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
308
  asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
309
- : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
310
- : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
311
  #else
312
  // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
313
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
314
- : "+r"(x[0]), "+r"(x[1])
315
- : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
316
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
317
- : "+r"(x[2]), "+r"(x[3])
318
- : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
319
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
320
- : "+r"(x[0]), "+r"(x[1])
321
- : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
322
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
323
- : "+r"(x[2]), "+r"(x[3])
324
- : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
325
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
326
  #else
327
- GGML_UNUSED(mma_A);
328
- GGML_UNUSED(mma_B);
 
329
  NO_DEVICE_CODE;
330
  #endif // NEW_MMA_AVAILABLE
331
  }
332
- };
333
-
334
- template <>
335
- struct mma_C_I16J8<half2> {
336
- static constexpr int I = 16;
337
- static constexpr int J = 4;
338
- static constexpr int ne = 2;
339
-
340
- half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
341
-
342
- static __device__ __forceinline__ int get_i(const int l) {
343
- const int ret = l * (I/2) + threadIdx.x / J;
344
- GGML_CUDA_ASSUME(ret >= 0);
345
- GGML_CUDA_ASSUME(ret < I);
346
- return ret;
347
- }
348
 
349
- static __device__ __forceinline__ int get_j(const int /* l */) {
350
- const int ret = threadIdx.x % J;
351
- GGML_CUDA_ASSUME(ret >= 0);
352
- GGML_CUDA_ASSUME(ret < J);
353
- return ret;
354
- }
355
-
356
- __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
357
  #ifdef NEW_MMA_AVAILABLE
358
- int * Axi = (int *) mma_A.x;
359
- int * Bxi = (int *) mma_B.x;
360
- int * xi = (int *) x;
361
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
362
  asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
363
- : "+r"(xi[0]), "+r"(xi[1])
364
  : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
365
  #else
366
  // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
367
  asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
368
- : "+r"(xi[0]), "+r"(xi[1])
369
  : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
370
  asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
371
- : "+r"(xi[0]), "+r"(xi[1])
372
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
373
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
374
  #else
375
- GGML_UNUSED(mma_A);
376
- GGML_UNUSED(mma_B);
 
377
  NO_DEVICE_CODE;
378
  #endif // NEW_MMA_AVAILABLE
379
  }
380
 
381
- __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
382
- mma_B_J8K8<half2> mma_B;
383
-
384
- int * xi = (int *) x;
385
- int * Bxi = (int *) mma_B.x;
386
- Bxi[0] = ggml_cuda_movmatrix(xi[0]);
387
- Bxi[1] = ggml_cuda_movmatrix(xi[1]);
388
-
389
- return mma_B;
390
- }
391
- };
392
-
393
- template <>
394
- struct mma_C_I16J8<float> {
395
- static constexpr int I = 16;
396
- static constexpr int J = 8;
397
- static constexpr int ne = 4;
398
-
399
- float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
400
-
401
- static __device__ __forceinline__ int get_i(const int l) {
402
- const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
403
- GGML_CUDA_ASSUME(ret >= 0);
404
- GGML_CUDA_ASSUME(ret < I);
405
- return ret;
406
- }
407
-
408
- static __device__ __forceinline__ int get_j(const int l) {
409
- const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
410
- GGML_CUDA_ASSUME(ret >= 0);
411
- GGML_CUDA_ASSUME(ret < J);
412
- return ret;
413
- }
414
-
415
- __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
416
  #ifdef NEW_MMA_AVAILABLE
417
- int * Axi = (int *) mma_A.x;
418
- int * Bxi = (int *) mma_B.x;
419
- int * xi = (int *) x;
420
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
421
  asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
422
- : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
423
  : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
424
  #else
425
  // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
426
  asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
427
- : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
428
  : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
429
  asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
430
- : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
431
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
432
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
433
  #else
434
- GGML_UNUSED(mma_A);
435
- GGML_UNUSED(mma_B);
 
436
  NO_DEVICE_CODE;
437
  #endif // NEW_MMA_AVAILABLE
438
  }
439
 
440
- __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
441
- mma_B_J8K8<half2> mma_B;
442
- mma_B.x[0] = make_half2(x[0], x[1]);
443
- mma_B.x[1] = make_half2(x[2], x[3]);
444
-
445
- int * Bxi = (int *) mma_B.x;
446
- Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
447
- Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
448
-
449
- return mma_B;
450
- }
451
-
452
- __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
453
- #pragma unroll
454
- for (int l = 0; l < ne; ++l) {
455
- x[l] = xs0[get_j(l)*stride + get_i(l)];
456
- }
457
- }
458
- };
 
4
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
5
  //
6
  // Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
7
+ // A is a row-major matrix with shape M x K.
8
+ // B is a column-major matrix with shape K x N.
9
+ // C is a column-major matrix with shape M x N.
10
+ // A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
11
+ // Note that J is measured in physical 32 bit elements instead of logical elements.
12
+ // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
13
  // All matrix tiles have ne physical 32 bit elements per warp.
14
  //
15
  // As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
 
24
 
25
  #ifdef NEW_MMA_AVAILABLE
26
  asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
27
+ : "=r"(ret) : "r"(x));
28
  #else
29
  NO_DEVICE_CODE;
30
  #endif // defined(NEW_MMA_AVAILABLE)
 
53
 
54
  #endif // CUDART_VERSION >= 11080
55
 
56
+ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
57
+ half2 ret;
58
+ *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
59
+ return ret;
60
+ }
61
 
62
+ namespace ggml_cuda_mma {
63
+
64
+ template <int I_, int J_, typename T>
65
+ struct tile {
66
+ static constexpr int I = I_;
67
+ static constexpr int J = J_;
68
+ static constexpr int ne = I * J / WARP_SIZE;
69
+ T x[ne] = {0};
70
+
71
+ static __device__ __forceinline__ int get_i(const int l) {
72
+ if constexpr (I == 8 && (J == 4 || J == 8)) {
73
+ return threadIdx.x / 4;
74
+ } else if constexpr (I == 16 && J == 8) {
75
+ return (l / 2) * 8 + threadIdx.x / 4;
76
+ } else {
77
+ static_assert(I == -1 && J == -1, "template specialization not implemented");
78
+ }
79
+ }
80
 
81
+ static __device__ __forceinline__ int get_j(const int l) {
82
+ if constexpr (I == 8 && J == 4) {
83
+ return threadIdx.x % 4;
84
+ } else if constexpr (I == 8 && J == 8) {
85
+ return 4 * l + threadIdx.x % 4;
86
+ } else if constexpr (I == 16 && J == 8) {
87
+ return 2 * (threadIdx.x % 4) + l % 2;
88
+ } else {
89
+ static_assert(I == -1 && J == -1, "template specialization not implemented");
90
+ }
91
+ }
92
+ };
93
+
94
+ template <int I_, int J_>
95
+ struct tile<I_, J_, half2> {
96
+ static constexpr int I = I_;
97
+ static constexpr int J = J_;
98
+ static constexpr int ne = I * J / WARP_SIZE;
99
+ half2 x[ne] = {{0.0f, 0.0f}};
100
+
101
+ static __device__ __forceinline__ int get_i(const int l) {
102
+ if constexpr (I == 8 && J == 8) {
103
+ return threadIdx.x / 4;
104
+ } else if constexpr (I == 16 && J == 4) {
105
+ return l * 8 + threadIdx.x / 4;
106
+ } else if constexpr (I == 16 && J == 8) {
107
+ return (l % 2) * 8 + threadIdx.x / 4;
108
+ } else {
109
+ static_assert(I == -1 && J == -1, "template specialization not implemented");
110
+ }
111
+ }
112
 
113
+ static __device__ __forceinline__ int get_j(const int l) {
114
+ if constexpr (I == 8 && J == 8) {
115
+ return l * 4 + threadIdx.x % 4;
116
+ } else if constexpr (I == 16 && J == 4) {
117
+ return threadIdx.x % 4;
118
+ } else if constexpr (I == 16 && J == 8) {
119
+ return (l / 2) * 4 + threadIdx.x % 4;
120
+ } else {
121
+ static_assert(I == -1 && J == -1, "template specialization not implemented");
122
+ }
123
+ }
124
+ };
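To make the index mapping concrete: for a tile<16, 8, int>, each of the 32 threads in a warp owns ne = 16*8/32 = 4 elements, and get_i/get_j recover their physical coordinates. The sketch below is illustrative only (it is not part of this commit; the kernel and buffer names are assumptions) and is expected to cover every (i, j) position of the 16x8 tile exactly once when launched with a single warp:

// Illustrative only: launch with one warp (32 threads); `out` holds 16*8 ints.
static __global__ void dump_tile_coords(int * out) {
    using T = ggml_cuda_mma::tile<16, 8, int>;
#pragma unroll
    for (int l = 0; l < T::ne; ++l) { // ne == 4 elements per thread
        // Thread threadIdx.x owns row (l/2)*8 + threadIdx.x/4 and column 2*(threadIdx.x%4) + l%2.
        out[T::get_i(l)*T::J + T::get_j(l)] = threadIdx.x;
    }
}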
125
 
126
+ template <int I, int J>
127
+ static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
128
+ tile<I, J/2, half2> ret;
129
  #pragma unroll
130
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
131
+ ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
132
  }
133
  return ret;
134
  }
135
 
136
+ static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
137
+ tile<8, 8, half2> ret;
138
+ ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
139
+ ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
140
+
141
  return ret;
142
  }
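For orientation, get_half2 and get_transposed together take over the job of the removed mma_C_I16J8<float>::to_mma_B(): pack the fp32 accumulator into half2 pairs, then transpose the resulting 16x4 half2 tile into an 8x8 B operand via movmatrix. A hedged sketch (the function name is an assumption, and it presumes `using namespace ggml_cuda_mma;` plus NEW_MMA_AVAILABLE for the movmatrix path):

// Illustrative only, not part of this commit.
static __device__ __forceinline__ tile<8, 8, half2> fp32_acc_to_B(const tile<16, 8, float> & C) {
    const tile<16, 4, half2> C_half2 = get_half2(C); // pack adjacent float pairs into half2
    return get_transposed(C_half2);                  // warp-wide 8x8 transpose via movmatrix
}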
143
 
144
+ template <int I, int J, typename T>
145
+ static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
146
  #pragma unroll
147
+ for (int l = 0; l < t.ne; ++l) {
148
+ t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
149
  }
150
  }
151
 
152
+ template <typename T>
153
+ static __device__ __forceinline__ void load_ldmatrix(
154
+ tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
155
  #ifdef NEW_MMA_AVAILABLE
156
+ int * xi = (int *) t.x;
157
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
158
+ asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
159
+ : "=r"(xi[0]), "=r"(xi[1])
160
  : "l"(xs));
161
  #else
162
+ load_generic(t, xs0, stride);
 
 
163
  #endif // NEW_MMA_AVAILABLE
164
  }
165
 
166
+ template <typename T>
167
+ static __device__ __forceinline__ void load_ldmatrix(
168
+ tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
169
  #ifdef NEW_MMA_AVAILABLE
170
+ int * xi = (int *) t.x;
171
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
172
+ asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
173
+ : "=r"(xi[0]), "=r"(xi[1])
174
  : "l"(xs));
175
  #else
176
+ load_generic(t, xs0, stride);
 
 
177
  #endif // NEW_MMA_AVAILABLE
178
  }
179
 
180
+ template <typename T>
181
+ static __device__ __forceinline__ void load_ldmatrix(
182
+ tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
183
  #ifdef NEW_MMA_AVAILABLE
184
+ int * xi = (int * ) t.x;
185
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
186
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
187
+ : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
188
+ : "l"(xs));
189
  #else
190
+ load_generic(t, xs0, stride);
191
  #endif // NEW_MMA_AVAILABLE
192
  }
193
 
194
+ template <typename T>
195
+ static __device__ __forceinline__ void load_ldmatrix_trans(
196
+ tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
197
  #ifdef NEW_MMA_AVAILABLE
198
+ int * xi = (int * ) t.x;
199
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
200
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
201
+ : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
202
  : "l"(xs));
203
  #else
204
+ GGML_UNUSED(t);
205
+ GGML_UNUSED(xs0);
206
+ GGML_UNUSED(stride);
207
+ NO_DEVICE_CODE;
208
  #endif // NEW_MMA_AVAILABLE
209
  }
210
 
211
+ static __device__ __forceinline__ void mma(
212
+ tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
213
  #ifdef NEW_MMA_AVAILABLE
214
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
215
  asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
216
+ : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
217
+ : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
218
  #else
219
  // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
220
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
221
+ : "+r"(D.x[0]), "+r"(D.x[1])
222
+ : "r"(A.x[0]), "r"(B.x[0]));
223
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
224
+ : "+r"(D.x[2]), "+r"(D.x[3])
225
+ : "r"(A.x[1]), "r"(B.x[0]));
226
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
227
  #else
228
+ GGML_UNUSED(D);
229
+ GGML_UNUSED(A);
230
+ GGML_UNUSED(B);
231
  NO_DEVICE_CODE;
232
  #endif // NEW_MMA_AVAILABLE
233
  }
234
 
235
+ static __device__ __forceinline__ void mma(
236
+ tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
237
  #ifdef NEW_MMA_AVAILABLE
238
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
239
  asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
240
+ : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
241
+ : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
242
  #else
243
  // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
244
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
245
+ : "+r"(D.x[0]), "+r"(D.x[1])
246
+ : "r"(A.x[0]), "r"(B.x[0]));
247
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
248
+ : "+r"(D.x[2]), "+r"(D.x[3])
249
+ : "r"(A.x[1]), "r"(B.x[0]));
250
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
251
+ : "+r"(D.x[0]), "+r"(D.x[1])
252
+ : "r"(A.x[2]), "r"(B.x[1]));
253
  asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
254
+ : "+r"(D.x[2]), "+r"(D.x[3])
255
+ : "r"(A.x[3]), "r"(B.x[1]));
256
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
257
  #else
258
+ GGML_UNUSED(D);
259
+ GGML_UNUSED(A);
260
+ GGML_UNUSED(B);
261
  NO_DEVICE_CODE;
262
  #endif // NEW_MMA_AVAILABLE
263
  }
264
 
265
+ static __device__ __forceinline__ void mma(
266
+ tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
267
  #ifdef NEW_MMA_AVAILABLE
268
+ const int * Axi = (const int *) A.x;
269
+ const int * Bxi = (const int *) B.x;
270
+ int * Dxi = (int *) D.x;
271
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
272
  asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
273
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
274
  : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
275
  #else
276
  // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
277
  asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
278
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
279
  : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
280
  asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
281
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
282
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
283
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
284
  #else
285
+ GGML_UNUSED(D);
286
+ GGML_UNUSED(A);
287
+ GGML_UNUSED(B);
288
  NO_DEVICE_CODE;
289
  #endif // NEW_MMA_AVAILABLE
290
  }
291
 
292
+ static __device__ __forceinline__ void mma(
293
+ tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
294
  #ifdef NEW_MMA_AVAILABLE
295
+ const int * Axi = (const int *) A.x;
296
+ const int * Bxi = (const int *) B.x;
297
+ int * Dxi = (int *) D.x;
298
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
299
  asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
300
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
301
  : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
302
  #else
303
  // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
304
  asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
305
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
306
  : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
307
  asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
308
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
309
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
310
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
311
  #else
312
+ GGML_UNUSED(D);
313
+ GGML_UNUSED(A);
314
+ GGML_UNUSED(B);
315
  NO_DEVICE_CODE;
316
  #endif // NEW_MMA_AVAILABLE
317
  }
318
 
319
+ }
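Taken together, the new free functions replace the per-struct load/mma members with overloads on the tile type. The kernel below is a hedged usage sketch rather than code from this commit: the kernel name, buffer names, launch shape (a single warp) and the 16x8 output layout are assumptions, and it requires NEW_MMA_AVAILABLE (Turing or newer) for the ldmatrix/mma paths:

// Illustrative only: one-warp demo of load_ldmatrix + mma with the new tile types.
using namespace ggml_cuda_mma;

static __global__ void example_f16_mma(const half2 * gA, const half2 * gB, float * out) {
    // ldmatrix requires 16-byte-aligned shared memory; out holds 16*8 floats.
    __shared__ __align__(16) half2 sA[16*8];
    __shared__ __align__(16) half2 sB[ 8*8];

    for (int k = threadIdx.x; k < 16*8; k += 32) { sA[k] = gA[k]; }
    for (int k = threadIdx.x; k <  8*8; k += 32) { sB[k] = gB[k]; }
    __syncwarp();

    tile<16, 8, half2> A;
    tile< 8, 8, half2> B;
    tile<16, 8, float> C; // accumulator, zero-initialized by the tile definition

    load_ldmatrix(A, sA, 8); // stride of 8 half2 elements per row
    load_ldmatrix(B, sB, 8);
    mma(C, A, B);            // C += A * B in the tile convention documented above

#pragma unroll
    for (int l = 0; l < C.ne; ++l) {
        out[C.get_i(l)*8 + C.get_j(l)] = C.x[l];
    }
}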
 
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -7,6 +7,8 @@
7
  #include <climits>
8
  #include <cstdint>
9
 
 
 
10
  #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
11
  #define MMQ_ITER_K 256
12
  #define MMQ_NWARPS 8
@@ -647,15 +649,15 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
647
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
648
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
649
 
650
- typedef mma_A_I16K8<int> mma_A;
651
- typedef mma_B_J8K8<int> mma_B;
652
- typedef mma_C_I16J8<int> mma_C;
653
 
654
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
655
  constexpr int rows_per_warp = 2 * granularity;
656
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
657
 
658
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
659
 
660
  const int * x_qs = (const int *) x;
661
  const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
@@ -663,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
663
  const float * y_df = (const float *) y;
664
  const half2 * y_ds = (const half2 *) y;
665
 
666
- mma_A A[ntx][WARP_SIZE/QI8_0];
667
- float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
668
 
669
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
670
 
@@ -674,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
674
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
675
  const int k0 = k00 + k01;
676
 
677
- A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
678
  }
679
 
680
  #pragma unroll
681
- for (int l = 0; l < mma_C::ne/2; ++l) {
682
- const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
683
 
684
  #pragma unroll
685
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
@@ -691,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
691
  }
692
 
693
  #pragma unroll
694
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
695
  #pragma unroll
696
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
697
- mma_B B;
698
- float dB[mma_C::ne/2];
699
 
700
- B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
701
 
702
  #pragma unroll
703
- for (int l = 0; l < mma_C::ne/2; ++l) {
704
- const int j = j0 + mma_C::get_j(l);
705
 
706
  if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
707
  dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -712,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
712
 
713
  #pragma unroll
714
  for (int n = 0; n < ntx; ++n) {
715
- mma_C C;
716
- C.mma(A[n][k01/QI8_0], B);
717
 
718
  #pragma unroll
719
- for (int l = 0; l < mma_C::ne; ++l) {
720
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
721
  }
722
  }
723
  }
@@ -758,23 +760,23 @@ template <int mmq_x, int mmq_y, int nwarps>
758
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
759
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
760
 
761
- typedef mma_A_I16K8<int> mma_A;
762
- typedef mma_B_J8K8<int> mma_B;
763
- typedef mma_C_I16J8<int> mma_C;
764
 
765
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
766
  constexpr int rows_per_warp = 2 * granularity;
767
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
768
 
769
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
770
 
771
  const int * x_qs = (const int *) x;
772
  const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
773
  const int * y_qs = (const int *) y + 4;
774
  const half2 * y_dm = (const half2 *) y;
775
 
776
- mma_A A[ntx][WARP_SIZE/QI8_1];
777
- float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
778
 
779
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
780
 
@@ -784,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
784
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
785
  const int k0 = k00 + k01;
786
 
787
- A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
788
  }
789
 
790
  #pragma unroll
791
- for (int l = 0; l < mma_C::ne/2; ++l) {
792
- const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
793
 
794
  #pragma unroll
795
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
@@ -801,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
801
  }
802
 
803
  #pragma unroll
804
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
805
  #pragma unroll
806
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
807
- mma_B B;
808
- float2 dsB[mma_C::ne/2];
809
 
810
- B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
811
 
812
  #pragma unroll
813
- for (int l = 0; l < mma_C::ne/2; ++l) {
814
- const int j = j0 + mma_C::get_j(l);
815
 
816
  dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
817
  }
818
 
819
  #pragma unroll
820
  for (int n = 0; n < ntx; ++n) {
821
- mma_C C;
822
- C.mma(A[n][k01/QI8_1], B);
823
 
824
  #pragma unroll
825
- for (int l = 0; l < mma_C::ne; ++l) {
826
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
827
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
828
  }
829
  }
830
  }
@@ -868,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
868
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
869
  #ifdef NEW_MMA_AVAILABLE
870
 
871
- typedef mma_A_I16K4<int> mma_A;
872
- typedef mma_A_I16K8<int> mma_A_K8;
873
- typedef mma_B_J8K4<int> mma_B;
874
- typedef mma_C_I16J8<int> mma_C;
875
 
876
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
877
  constexpr int rows_per_warp = 2 * granularity;
878
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
879
 
880
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
881
 
882
  const int * x_qs = (const int *) x;
883
  const float * x_df = (const float *) x_qs + WARP_SIZE*2;
884
  const int * y_qs = (const int *) y + 4;
885
  const float * y_df = (const float *) y;
886
 
887
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
888
 
889
- mma_A A[ntx][8];
890
- float dA[ntx][mma_C::ne/2][8];
891
 
892
  #pragma unroll
893
  for (int n = 0; n < ntx; ++n) {
@@ -895,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
895
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
896
  const int k0 = k00 + k01;
897
 
898
- ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
899
  }
900
 
901
  #pragma unroll
902
- for (int l = 0; l < mma_C::ne/2; ++l) {
903
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
904
 
905
  #pragma unroll
906
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
@@ -912,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
912
  }
913
 
914
  #pragma unroll
915
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
916
  #pragma unroll
917
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
918
- mma_B B[2];
919
- float dB[mma_C::ne/2];
920
 
921
  // Here load_generic is faster than load_ldmatrix.
922
- B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
923
- B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
924
 
925
  #pragma unroll
926
- for (int l = 0; l < mma_C::ne/2; ++l) {
927
- const int j = j0 + mma_C::get_j(l);
928
 
929
  dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
930
  }
931
 
932
  #pragma unroll
933
  for (int n = 0; n < ntx; ++n) {
934
- mma_C C[2];
935
- C[0].mma(A[n][k01/4 + 0], B[0]);
936
- C[1].mma(A[n][k01/4 + 1], B[1]);
937
 
938
  #pragma unroll
939
- for (int l = 0; l < mma_C::ne; ++l) {
940
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
941
  }
942
  }
943
  }
@@ -1056,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1056
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1057
  #ifdef NEW_MMA_AVAILABLE
1058
 
1059
- typedef mma_A_I16K4<int> mma_A;
1060
- typedef mma_A_I16K8<int> mma_A_K8;
1061
- typedef mma_B_J8K4<int> mma_B;
1062
- typedef mma_C_I16J8<int> mma_C;
1063
 
1064
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1065
  constexpr int rows_per_warp = 2 * granularity;
1066
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
1067
 
1068
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
1069
 
1070
  const int * x_qs = (const int *) x;
1071
  const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
1072
  const int * y_qs = (const int *) y + 4;
1073
  const half2 * y_ds = (const half2 *) y;
1074
 
1075
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1076
 
1077
- mma_A A[ntx][8];
1078
- float dA[ntx][mma_C::ne/2][8];
1079
- float mA[ntx][mma_C::ne/2][8];
1080
 
1081
  #pragma unroll
1082
  for (int n = 0; n < ntx; ++n) {
@@ -1084,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1084
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1085
  const int k0 = k00 + k01;
1086
 
1087
- ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1088
  }
1089
  }
1090
 
1091
  #pragma unroll
1092
  for (int n = 0; n < ntx; ++n) {
1093
  #pragma unroll
1094
- for (int l = 0; l < mma_C::ne/2; ++l) {
1095
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1096
 
1097
  #pragma unroll
1098
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
@@ -1107,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1107
  }
1108
 
1109
  #pragma unroll
1110
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
1111
- float2 dB[mma_C::ne/2];
1112
 
1113
  #pragma unroll
1114
- for (int l = 0; l < mma_C::ne/2; ++l) {
1115
- const int j = j0 + mma_C::get_j(l);
1116
 
1117
  dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
1118
  }
1119
 
1120
  #pragma unroll
1121
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1122
- mma_B B[2];
1123
 
1124
  // Here load_generic is faster than load_ldmatrix.
1125
- B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
1126
- B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
1127
 
1128
- mma_C Cm[2];
1129
  if (k01 >= WARP_SIZE * 3/4) {
1130
- mma_A A1;
1131
  A1.x[0] = 0x01010101;
1132
  A1.x[1] = 0x01010101;
1133
- Cm[0].mma(A1, B[0]);
1134
- Cm[1].mma(A1, B[1]);
1135
  }
1136
 
1137
  #pragma unroll
1138
  for (int n = 0; n < ntx; ++n) {
1139
- mma_C Cd[2];
1140
 
1141
- Cd[0].mma(A[n][k01/4 + 0], B[0]);
1142
- Cd[1].mma(A[n][k01/4 + 1], B[1]);
1143
 
1144
  #pragma unroll
1145
- for (int l = 0; l < mma_C::ne; ++l) {
1146
  float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
1147
  if (k01 >= WARP_SIZE * 3/4) {
1148
  tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
1149
  }
1150
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
1151
  }
1152
  }
1153
  }
1154
 
1155
  #pragma unroll
1156
  for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
1157
- float2 sB[mma_C::ne/2];
1158
 
1159
  #pragma unroll
1160
- for (int l = 0; l < mma_C::ne/2; ++l) {
1161
- const int j = j0 + mma_C::get_j(l);
1162
 
1163
  sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1164
  }
@@ -1166,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1166
  #pragma unroll
1167
  for (int n = 0; n < ntx; ++n) {
1168
  #pragma unroll
1169
- for (int l = 0; l < mma_C::ne; ++l) {
1170
- sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
1171
- sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
1172
  }
1173
  }
1174
  }
@@ -1708,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1708
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1709
  #ifdef NEW_MMA_AVAILABLE
1710
 
1711
- typedef mma_A_I16K4<int> mma_A;
1712
- typedef mma_B_J8K4<int> mma_B;
1713
- typedef mma_C_I16J8<int> mma_C;
1714
 
1715
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1716
  constexpr int rows_per_warp = 2 * granularity;
1717
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
1718
 
1719
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
1720
 
1721
  const int * x_qs = (const int *) x;
1722
  const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1724,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1724
  const int * y_qs = (const int *) y + 4;
1725
  const float * y_df = (const float *) y;
1726
 
1727
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1728
 
1729
- mma_A A[ntx][8];
1730
- int scA[ntx][mma_C::ne/2][8];
1731
- float dA[ntx][mma_C::ne/2];
1732
 
1733
  #pragma unroll
1734
  for (int n = 0; n < ntx; ++n) {
@@ -1736,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1736
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1737
  const int k0 = k00 + k01;
1738
 
1739
- A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
1740
- A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
1741
  }
1742
 
1743
  #pragma unroll
@@ -1745,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1745
  const int k0 = k00 + k01;
1746
 
1747
  #pragma unroll
1748
- for (int l = 0; l < mma_C::ne/2; ++l) {
1749
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1750
 
1751
  const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
1752
  const int8_t * sc = (const int8_t *) &sc_packed;
@@ -1759,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1759
  }
1760
 
1761
  #pragma unroll
1762
- for (int l = 0; l < mma_C::ne/2; ++l) {
1763
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1764
 
1765
  dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
1766
  }
1767
  }
1768
 
1769
  #pragma unroll
1770
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
1771
- float tmp[ntx][mma_C::ne] = {{0.0f}};
1772
 
1773
  #pragma unroll
1774
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1775
- mma_B B[2];
1776
- float dB[mma_C::ne/2];
1777
 
1778
  // Here load_generic is faster than load_ldmatrix.
1779
- B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
1780
- B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
1781
 
1782
  #pragma unroll
1783
- for (int l = 0; l < mma_C::ne/2; ++l) {
1784
- const int j = j0 + mma_C::get_j(l);
1785
 
1786
  dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
1787
  }
1788
 
1789
  #pragma unroll
1790
  for (int n = 0; n < ntx; ++n) {
1791
- mma_C C[2];
1792
- C[0].mma(A[n][k01/4 + 0], B[0]);
1793
- C[1].mma(A[n][k01/4 + 1], B[1]);
1794
 
1795
  #pragma unroll
1796
- for (int l = 0; l < mma_C::ne; ++l) {
1797
  tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
1798
  }
1799
  }
@@ -1802,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1802
  #pragma unroll
1803
  for (int n = 0; n < ntx; ++n) {
1804
  #pragma unroll
1805
- for (int l = 0; l < mma_C::ne; ++l) {
1806
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
1807
  }
1808
  }
1809
  }
@@ -2312,36 +2314,36 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
2312
  static __device__ __forceinline__ void mmq_write_back_mma(
2313
  const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
2314
 
2315
- typedef mma_C_I16J8<int> mma_C;
2316
 
2317
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
2318
  constexpr int rows_per_warp = 2 * granularity;
2319
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
2320
 
2321
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
2322
  #ifdef NEW_MMA_AVAILABLE
2323
- static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
2324
  #endif // NEW_MMA_AVAILABLE
2325
 
2326
  #pragma unroll
2327
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
2328
  #pragma unroll
2329
  for (int n = 0; n < ntx; ++n) {
2330
  #pragma unroll
2331
- for (int l = 0; l < mma_C::ne; ++l) {
2332
- const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
2333
 
2334
  if (j > j_max) {
2335
  continue;
2336
  }
2337
 
2338
- const int i = i0 + n*mma_C::I + mma_C::get_i(l);
2339
 
2340
  if (need_check && i > i_max) {
2341
  continue;
2342
  }
2343
 
2344
- dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
2345
  }
2346
  }
2347
  }
 
7
  #include <climits>
8
  #include <cstdint>
9
 
10
+ using namespace ggml_cuda_mma;
11
+
12
  #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
13
  #define MMQ_ITER_K 256
14
  #define MMQ_NWARPS 8
 
649
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
650
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
651
 
652
+ typedef tile<16, 8, int> tile_A;
653
+ typedef tile< 8, 8, int> tile_B;
654
+ typedef tile<16, 8, int> tile_C;
655
 
656
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
657
  constexpr int rows_per_warp = 2 * granularity;
658
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
659
 
660
+ y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
661
 
662
  const int * x_qs = (const int *) x;
663
  const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
 
665
  const float * y_df = (const float *) y;
666
  const half2 * y_ds = (const half2 *) y;
667
 
668
+ tile_A A[ntx][WARP_SIZE/QI8_0];
669
+ float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
670
 
671
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
672
 
 
676
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
677
  const int k0 = k00 + k01;
678
 
679
+ load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
680
  }
681
 
682
  #pragma unroll
683
+ for (int l = 0; l < tile_C::ne/2; ++l) {
684
+ const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
685
 
686
  #pragma unroll
687
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
 
693
  }
694
 
695
  #pragma unroll
696
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
697
  #pragma unroll
698
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
699
+ tile_B B;
700
+ float dB[tile_C::ne/2];
701
 
702
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
703
 
704
  #pragma unroll
705
+ for (int l = 0; l < tile_C::ne/2; ++l) {
706
+ const int j = j0 + tile_C::get_j(l);
707
 
708
  if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
709
  dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
 
714
 
715
  #pragma unroll
716
  for (int n = 0; n < ntx; ++n) {
717
+ tile_C C;
718
+ mma(C, A[n][k01/QI8_0], B);
719
 
720
  #pragma unroll
721
+ for (int l = 0; l < tile_C::ne; ++l) {
722
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
723
  }
724
  }
725
  }
 
760
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
761
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
762
 
763
+ typedef tile<16, 8, int> tile_A;
764
+ typedef tile< 8, 8, int> tile_B;
765
+ typedef tile<16, 8, int> tile_C;
766
 
767
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
768
  constexpr int rows_per_warp = 2 * granularity;
769
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
770
 
771
+ y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
772
 
773
  const int * x_qs = (const int *) x;
774
  const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
775
  const int * y_qs = (const int *) y + 4;
776
  const half2 * y_dm = (const half2 *) y;
777
 
778
+ tile_A A[ntx][WARP_SIZE/QI8_1];
779
+ float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
780
 
781
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
782
 
 
786
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
787
  const int k0 = k00 + k01;
788
 
789
+ load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
790
  }
791
 
792
  #pragma unroll
793
+ for (int l = 0; l < tile_C::ne/2; ++l) {
794
+ const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
795
 
796
  #pragma unroll
797
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
 
803
  }
804
 
805
  #pragma unroll
806
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
807
  #pragma unroll
808
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
809
+ tile_B B;
810
+ float2 dsB[tile_C::ne/2];
811
 
812
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
813
 
814
  #pragma unroll
815
+ for (int l = 0; l < tile_C::ne/2; ++l) {
816
+ const int j = j0 + tile_C::get_j(l);
817
 
818
  dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
819
  }
820
 
821
  #pragma unroll
822
  for (int n = 0; n < ntx; ++n) {
823
+ tile_C C;
824
+ mma(C, A[n][k01/QI8_1], B);
825
 
826
  #pragma unroll
827
+ for (int l = 0; l < tile_C::ne; ++l) {
828
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
829
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
830
  }
831
  }
832
  }
 
870
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
871
  #ifdef NEW_MMA_AVAILABLE
872
 
873
+ typedef tile<16, 4, int> tile_A;
874
+ typedef tile<16, 8, int> tile_A_8;
875
+ typedef tile< 8, 4, int> tile_B;
876
+ typedef tile<16, 8, int> tile_C;
877
 
878
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
879
  constexpr int rows_per_warp = 2 * granularity;
880
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
881
 
882
+ y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
883
 
884
  const int * x_qs = (const int *) x;
885
  const float * x_df = (const float *) x_qs + WARP_SIZE*2;
886
  const int * y_qs = (const int *) y + 4;
887
  const float * y_df = (const float *) y;
888
 
889
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
890
 
891
+ tile_A A[ntx][8];
892
+ float dA[ntx][tile_C::ne/2][8];
893
 
894
  #pragma unroll
895
  for (int n = 0; n < ntx; ++n) {
 
897
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
898
  const int k0 = k00 + k01;
899
 
900
+ load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
901
  }
902
 
903
  #pragma unroll
904
+ for (int l = 0; l < tile_C::ne/2; ++l) {
905
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
906
 
907
  #pragma unroll
908
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
 
914
  }
915
 
916
  #pragma unroll
917
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
918
  #pragma unroll
919
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
920
+ tile_B B[2];
921
+ float dB[tile_C::ne/2];
922
 
923
  // Here load_generic is faster than load_ldmatrix.
924
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
925
+ load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
926
 
927
  #pragma unroll
928
+ for (int l = 0; l < tile_C::ne/2; ++l) {
929
+ const int j = j0 + tile_C::get_j(l);
930
 
931
  dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
932
  }
933
 
934
  #pragma unroll
935
  for (int n = 0; n < ntx; ++n) {
936
+ tile_C C[2];
937
+ mma(C[0], A[n][k01/4 + 0], B[0]);
938
+ mma(C[1], A[n][k01/4 + 1], B[1]);
939
 
940
  #pragma unroll
941
+ for (int l = 0; l < tile_C::ne; ++l) {
942
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
943
  }
944
  }
945
  }
 
1058
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1059
  #ifdef NEW_MMA_AVAILABLE
1060
 
1061
+ typedef tile<16, 4, int> tile_A;
1062
+ typedef tile<16, 8, int> tile_A_8;
1063
+ typedef tile< 8, 4, int> tile_B;
1064
+ typedef tile<16, 8, int> tile_C;
1065
 
1066
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1067
  constexpr int rows_per_warp = 2 * granularity;
1068
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1069
 
1070
+ y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
1071
 
1072
  const int * x_qs = (const int *) x;
1073
  const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
1074
  const int * y_qs = (const int *) y + 4;
1075
  const half2 * y_ds = (const half2 *) y;
1076
 
1077
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
1078
 
1079
+ tile_A A[ntx][8];
1080
+ float dA[ntx][tile_C::ne/2][8];
1081
+ float mA[ntx][tile_C::ne/2][8];
1082
 
1083
  #pragma unroll
1084
  for (int n = 0; n < ntx; ++n) {
 
1086
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1087
  const int k0 = k00 + k01;
1088
 
1089
+ load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1090
  }
1091
  }
1092
 
1093
  #pragma unroll
1094
  for (int n = 0; n < ntx; ++n) {
1095
  #pragma unroll
1096
+ for (int l = 0; l < tile_C::ne/2; ++l) {
1097
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
1098
 
1099
  #pragma unroll
1100
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
 
1109
  }
1110
 
1111
  #pragma unroll
1112
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1113
+ float2 dB[tile_C::ne/2];
1114
 
1115
  #pragma unroll
1116
+ for (int l = 0; l < tile_C::ne/2; ++l) {
1117
+ const int j = j0 + tile_C::get_j(l);
1118
 
1119
  dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
1120
  }
1121
 
1122
  #pragma unroll
1123
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1124
+ tile_B B[2];
1125
 
1126
  // Here load_generic is faster than load_ldmatrix.
1127
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
1128
+ load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
1129
 
1130
+ tile_C Cm[2];
1131
  if (k01 >= WARP_SIZE * 3/4) {
1132
+ tile_A A1;
1133
  A1.x[0] = 0x01010101;
1134
  A1.x[1] = 0x01010101;
1135
+ mma(Cm[0], A1, B[0]);
1136
+ mma(Cm[1], A1, B[1]);
1137
  }
1138
 
1139
  #pragma unroll
1140
  for (int n = 0; n < ntx; ++n) {
1141
+ tile_C Cd[2];
1142
 
1143
+ mma(Cd[0], A[n][k01/4 + 0], B[0]);
1144
+ mma(Cd[1], A[n][k01/4 + 1], B[1]);
1145
 
1146
  #pragma unroll
1147
+ for (int l = 0; l < tile_C::ne; ++l) {
1148
  float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
1149
  if (k01 >= WARP_SIZE * 3/4) {
1150
  tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
1151
  }
1152
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
1153
  }
1154
  }
1155
  }
1156
 
1157
  #pragma unroll
1158
  for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
1159
+ float2 sB[tile_C::ne/2];
1160
 
1161
  #pragma unroll
1162
+ for (int l = 0; l < tile_C::ne/2; ++l) {
1163
+ const int j = j0 + tile_C::get_j(l);
1164
 
1165
  sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1166
  }
 
1168
  #pragma unroll
1169
  for (int n = 0; n < ntx; ++n) {
1170
  #pragma unroll
1171
+ for (int l = 0; l < tile_C::ne; ++l) {
1172
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
1173
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
1174
  }
1175
  }
1176
  }
 
1710
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1711
  #ifdef NEW_MMA_AVAILABLE
1712
 
1713
+ typedef tile<16, 4, int> tile_A;
1714
+ typedef tile< 8, 4, int> tile_B;
1715
+ typedef tile<16, 8, int> tile_C;
1716
 
1717
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1718
  constexpr int rows_per_warp = 2 * granularity;
1719
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1720
 
1721
+ y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
1722
 
1723
  const int * x_qs = (const int *) x;
1724
  const float * x_df = (const float *) x_qs + WARP_SIZE*2;
 
1726
  const int * y_qs = (const int *) y + 4;
1727
  const float * y_df = (const float *) y;
1728
 
1729
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
1730
 
1731
+ tile_A A[ntx][8];
1732
+ int scA[ntx][tile_C::ne/2][8];
1733
+ float dA[ntx][tile_C::ne/2];
1734
 
1735
  #pragma unroll
1736
  for (int n = 0; n < ntx; ++n) {
 
1738
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1739
  const int k0 = k00 + k01;
1740
 
1741
+ load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
1742
+ load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
1743
  }
1744
 
1745
  #pragma unroll
 
1747
  const int k0 = k00 + k01;
1748
 
1749
  #pragma unroll
1750
+ for (int l = 0; l < tile_C::ne/2; ++l) {
1751
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
1752
 
1753
  const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
1754
  const int8_t * sc = (const int8_t *) &sc_packed;
 
1761
  }
1762
 
1763
  #pragma unroll
1764
+ for (int l = 0; l < tile_C::ne/2; ++l) {
1765
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
1766
 
1767
  dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
1768
  }
1769
  }
1770
 
1771
  #pragma unroll
1772
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1773
+ float tmp[ntx][tile_C::ne] = {{0.0f}};
1774
 
1775
  #pragma unroll
1776
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1777
+ tile_B B[2];
1778
+ float dB[tile_C::ne/2];
1779
 
1780
  // Here load_generic is faster than load_ldmatrix.
1781
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
1782
+ load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
1783
 
1784
  #pragma unroll
1785
+ for (int l = 0; l < tile_C::ne/2; ++l) {
1786
+ const int j = j0 + tile_C::get_j(l);
1787
 
1788
  dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
1789
  }
1790
 
1791
  #pragma unroll
1792
  for (int n = 0; n < ntx; ++n) {
1793
+ tile_C C[2];
1794
+ mma(C[0], A[n][k01/4 + 0], B[0]);
1795
+ mma(C[1], A[n][k01/4 + 1], B[1]);
1796
 
1797
  #pragma unroll
1798
+ for (int l = 0; l < tile_C::ne; ++l) {
1799
  tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
1800
  }
1801
  }
 
1804
  #pragma unroll
1805
  for (int n = 0; n < ntx; ++n) {
1806
  #pragma unroll
1807
+ for (int l = 0; l < tile_C::ne; ++l) {
1808
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
1809
  }
1810
  }
1811
  }
 
2314
  static __device__ __forceinline__ void mmq_write_back_mma(
2315
  const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
2316
 
2317
+ typedef tile<16, 8, int> tile_C;
2318
 
2319
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
2320
  constexpr int rows_per_warp = 2 * granularity;
2321
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2322
 
2323
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
2324
  #ifdef NEW_MMA_AVAILABLE
2325
+ static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
2326
  #endif // NEW_MMA_AVAILABLE
2327
 
2328
  #pragma unroll
2329
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2330
  #pragma unroll
2331
  for (int n = 0; n < ntx; ++n) {
2332
  #pragma unroll
2333
+ for (int l = 0; l < tile_C::ne; ++l) {
2334
+ const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
2335
 
2336
  if (j > j_max) {
2337
  continue;
2338
  }
2339
 
2340
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2341
 
2342
  if (need_check && i > i_max) {
2343
  continue;
2344
  }
2345
 
2346
+ dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
2347
  }
2348
  }
2349
  }