JohannesGaessler committed on
Commit 78a5b67 · 1 Parent(s): 9f87c2f

CUDA: use tensor cores for MMQ (llama/7676)


* CUDA: int8 tensor cores for MMQ (legacy quants)

* fix out-of-bounds writes

* __builtin_assume -> GGML_CUDA_ASSUME

* fix writeback returning too early
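For context on the first bullet: the existing MMQ kernels accumulate int8 dot products one thread at a time with __dp4a, where each 32-bit operand packs four int8 values, while the new path issues warp-wide mma.sync PTX instructions that produce a whole 16x8 int32 tile per call. A minimal sketch of that __dp4a baseline, with a hypothetical helper name that is not part of this commit:

// Sketch only: hypothetical helper, not from this commit. Each int packs 4 int8 values;
// __dp4a multiplies them pairwise and adds the result to the running accumulator.
static __device__ __forceinline__ int dot_q8_dp4a(const int * a, const int * b, const int n32) {
    int sumi = 0;
    for (int i = 0; i < n32; ++i) {
        sumi = __dp4a(a[i], b[i], sumi);
    }
    return sumi;
}

The tensor-core path added below replaces this inner loop for the legacy quant formats (Q4_0, Q4_1, Q5_0, Q5_1, Q8_0) on Turing and newer NVIDIA GPUs; the K quants keep the __dp4a code path in this commit.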

ggml-cuda/common.cuh CHANGED
@@ -139,6 +139,7 @@
139
  #define CC_PASCAL 600
140
  #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
141
  #define CC_VOLTA 700
 
142
  #define CC_AMPERE 800
143
  #define CC_OFFSET_AMD 1000000
144
  #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
@@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
326
  #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
327
  #endif // defined(GGML_USE_HIPBLAS)
328
 
329
- #define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
330
 
331
- #define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
332
 
333
  static bool fast_fp16_available(const int cc) {
334
  return cc >= CC_PASCAL && cc != 610;
@@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) {
338
  return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
339
  }
340
 
341
  [[noreturn]]
342
  static __device__ void no_device_code(
343
  const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
379
  }
380
 
381
  static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
382
- #if FP16_AVAILABLE
383
 
384
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
385
  #pragma unroll
@@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
412
  }
413
 
414
  static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
415
- #if FP16_AVAILABLE
416
 
417
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
418
  return __float2half(fmaxf(__half2float(a), __half2float(b)));
 
139
  #define CC_PASCAL 600
140
  #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
141
  #define CC_VOLTA 700
142
+ #define CC_TURING 750
143
  #define CC_AMPERE 800
144
  #define CC_OFFSET_AMD 1000000
145
  #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 
327
  #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
328
  #endif // defined(GGML_USE_HIPBLAS)
329
 
330
+ #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
331
+ #define FP16_AVAILABLE
332
+ #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
333
 
334
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
335
+ #define FP16_MMA_AVAILABLE
336
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
337
+
338
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
339
+ #define INT8_MMA_AVAILABLE
340
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
341
 
342
  static bool fast_fp16_available(const int cc) {
343
  return cc >= CC_PASCAL && cc != 610;
 
347
  return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
348
  }
349
 
350
+ static bool int8_mma_available(const int cc) {
351
+ return cc < CC_OFFSET_AMD && cc >= CC_TURING;
352
+ }
353
+
354
  [[noreturn]]
355
  static __device__ void no_device_code(
356
  const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
 
392
  }
393
 
394
  static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
395
+ #ifdef FP16_AVAILABLE
396
 
397
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
398
  #pragma unroll
 
425
  }
426
 
427
  static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
428
+ #ifdef FP16_AVAILABLE
429
 
430
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
431
  return __float2half(fmaxf(__half2float(a), __half2float(b)));
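The common.cuh changes above switch FP16_AVAILABLE and FP16_MMA_AVAILABLE from expression-style macros tested with #if to presence-only flags tested with #ifdef, and add the analogous INT8_MMA_AVAILABLE flag plus the host-side int8_mma_available(cc) check for Turing (CC_TURING, 750) and newer. A minimal sketch of how device code is gated under the new scheme, with a hypothetical function name; warp_reduce_sum and ggml_cuda_hmax in the hunks above are the real call sites, and the real warp_reduce_sum additionally special-cases HIP:

// Sketch, assuming the macros and helpers from ggml-cuda/common.cuh above.
static __device__ __forceinline__ half2 example_reduce(half2 a) {
#ifdef FP16_AVAILABLE
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
    }
    return a;
#else
    GGML_UNUSED(a);
    NO_DEVICE_CODE; // compiling for an arch without fast fp16 support
#endif // FP16_AVAILABLE
}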
ggml-cuda/fattn-common.cuh CHANGED
@@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
74
 
75
  const int sumi = __dp4a(v, u, 0);
76
 
77
- #if FP16_AVAILABLE
78
  if (std::is_same<T, half>::value) {
79
  const half2 * Q_ds = (const half2 *) Q_ds_v;
80
 
@@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
122
 
123
  const int sumi = __dp4a(v, u, 0);
124
 
125
- #if FP16_AVAILABLE
126
  if (std::is_same<T, half>::value) {
127
  const half2 * Q_ds = (const half2 *) Q_ds_v;
128
 
@@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
181
 
182
  const int sumi = __dp4a(v, u, 0);
183
 
184
- #if FP16_AVAILABLE
185
  if (std::is_same<T, half>::value) {
186
  const half2 * Q_ds = (const half2 *) Q_ds_v;
187
 
@@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
236
 
237
  const int sumi = __dp4a(v, u, 0);
238
 
239
- #if FP16_AVAILABLE
240
  if (std::is_same<T, half>::value) {
241
  const half2 * Q_ds = (const half2 *) Q_ds_v;
242
 
@@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
314
  GGML_UNUSED(Q_q8);
315
  GGML_UNUSED(Q_ds_v);
316
 
317
- #if FP16_AVAILABLE
318
  if (std::is_same<T, half>::value) {
319
  const half2 * Q_h2 = (const half2 *) Q_v;
320
 
@@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
407
  const int q0 = x[ib].qs[iqs];
408
  const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
409
 
410
- #if FP16_AVAILABLE
411
  if (std::is_same<T, half>::value) {
412
  return ((half) d)*((half) q);
413
  }
@@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
428
  const int q0 = x[ib].qs[iqs];
429
  const int q = ((q0 >> (4*shift)) & 0x0F);
430
 
431
- #if FP16_AVAILABLE
432
  if (std::is_same<T, half>::value) {
433
  return __low2half(dm)*((half) q) + __high2half(dm);
434
  }
@@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
453
  const int qh = ((qh0 >> idq) << 4) & 0x10;
454
  const int q = (ql | qh) - 16;
455
 
456
- #if FP16_AVAILABLE
457
  if (std::is_same<T, half>::value) {
458
  return ((half) d)*((half) q);
459
  }
@@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
478
  const int qh = ((qh0 >> idq) << 4) & 0x10;
479
  const int q = (ql | qh);
480
 
481
- #if FP16_AVAILABLE
482
  if (std::is_same<T, half>::value) {
483
  return __low2half(dm)*((half) q) + __high2half(dm);
484
  }
@@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
497
  const T d = x[ib].d;
498
  const int q = x[ib].qs[iqs];
499
 
500
- #if FP16_AVAILABLE
501
  if (std::is_same<T, half>::value) {
502
  return ((half) d)*((half) q);
503
  }
 
74
 
75
  const int sumi = __dp4a(v, u, 0);
76
 
77
+ #ifdef FP16_AVAILABLE
78
  if (std::is_same<T, half>::value) {
79
  const half2 * Q_ds = (const half2 *) Q_ds_v;
80
 
 
122
 
123
  const int sumi = __dp4a(v, u, 0);
124
 
125
+ #ifdef FP16_AVAILABLE
126
  if (std::is_same<T, half>::value) {
127
  const half2 * Q_ds = (const half2 *) Q_ds_v;
128
 
 
181
 
182
  const int sumi = __dp4a(v, u, 0);
183
 
184
+ #ifdef FP16_AVAILABLE
185
  if (std::is_same<T, half>::value) {
186
  const half2 * Q_ds = (const half2 *) Q_ds_v;
187
 
 
236
 
237
  const int sumi = __dp4a(v, u, 0);
238
 
239
+ #ifdef FP16_AVAILABLE
240
  if (std::is_same<T, half>::value) {
241
  const half2 * Q_ds = (const half2 *) Q_ds_v;
242
 
 
314
  GGML_UNUSED(Q_q8);
315
  GGML_UNUSED(Q_ds_v);
316
 
317
+ #ifdef FP16_AVAILABLE
318
  if (std::is_same<T, half>::value) {
319
  const half2 * Q_h2 = (const half2 *) Q_v;
320
 
 
407
  const int q0 = x[ib].qs[iqs];
408
  const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
409
 
410
+ #ifdef FP16_AVAILABLE
411
  if (std::is_same<T, half>::value) {
412
  return ((half) d)*((half) q);
413
  }
 
428
  const int q0 = x[ib].qs[iqs];
429
  const int q = ((q0 >> (4*shift)) & 0x0F);
430
 
431
+ #ifdef FP16_AVAILABLE
432
  if (std::is_same<T, half>::value) {
433
  return __low2half(dm)*((half) q) + __high2half(dm);
434
  }
 
453
  const int qh = ((qh0 >> idq) << 4) & 0x10;
454
  const int q = (ql | qh) - 16;
455
 
456
+ #ifdef FP16_AVAILABLE
457
  if (std::is_same<T, half>::value) {
458
  return ((half) d)*((half) q);
459
  }
 
478
  const int qh = ((qh0 >> idq) << 4) & 0x10;
479
  const int q = (ql | qh);
480
 
481
+ #ifdef FP16_AVAILABLE
482
  if (std::is_same<T, half>::value) {
483
  return __low2half(dm)*((half) q) + __high2half(dm);
484
  }
 
497
  const T d = x[ib].d;
498
  const int q = x[ib].qs[iqs];
499
 
500
+ #ifdef FP16_AVAILABLE
501
  if (std::is_same<T, half>::value) {
502
  return ((half) d)*((half) q);
503
  }
ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16(
43
  const int ne1,
44
  const int ne2,
45
  const int ne3) {
46
- #if FP16_AVAILABLE
47
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
48
 
49
  const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
 
43
  const int ne1,
44
  const int ne2,
45
  const int ne3) {
46
+ #ifdef FP16_AVAILABLE
47
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
48
 
49
  const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16(
40
  const int ne1,
41
  const int ne2,
42
  const int ne3) {
43
- #if FP16_AVAILABLE
44
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
45
 
46
  constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
 
40
  const int ne1,
41
  const int ne2,
42
  const int ne3) {
43
+ #ifdef FP16_AVAILABLE
44
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
45
 
46
  constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
ggml-cuda/fattn-wmma-f16.cuh CHANGED
@@ -1,9 +1,9 @@
1
  #include "common.cuh"
2
  #include "fattn-common.cuh"
3
 
4
- #if FP16_MMA_AVAILABLE
5
  #include <mma.h>
6
- #endif
7
 
8
  // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
9
  template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
45
  const int ne1,
46
  const int ne2,
47
  const int ne3) {
48
- #if FP16_MMA_AVAILABLE
49
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
50
 
51
  const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
 
1
  #include "common.cuh"
2
  #include "fattn-common.cuh"
3
 
4
+ #ifdef FP16_MMA_AVAILABLE
5
  #include <mma.h>
6
+ #endif // FP16_MMA_AVAILABLE
7
 
8
  // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
9
  template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
 
45
  const int ne1,
46
  const int ne2,
47
  const int ne3) {
48
+ #ifdef FP16_MMA_AVAILABLE
49
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
50
 
51
  const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
ggml-cuda/mma.cuh ADDED
@@ -0,0 +1,95 @@
1
+ #include "common.cuh"
2
+
3
+ struct mma_int_A_I16K8 {
4
+ static constexpr int I = 16;
5
+ static constexpr int K = 8;
6
+ static constexpr int ne = 4;
7
+
8
+ int x[ne] = {0};
9
+
10
+ static __device__ __forceinline__ int get_i(const int l) {
11
+ const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
12
+ GGML_CUDA_ASSUME(ret >= 0);
13
+ GGML_CUDA_ASSUME(ret < I);
14
+ return ret;
15
+ }
16
+
17
+ static __device__ __forceinline__ int get_k(const int l) {
18
+ const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
19
+ GGML_CUDA_ASSUME(ret >= 0);
20
+ GGML_CUDA_ASSUME(ret < K);
21
+ return ret;
22
+ }
23
+ };
24
+
25
+ struct mma_int_B_J8K8 {
26
+ static constexpr int J = 8;
27
+ static constexpr int K = 8;
28
+ static constexpr int ne = 2;
29
+
30
+ int x[ne] = {0};
31
+
32
+ static __device__ __forceinline__ int get_j(const int /* l */) {
33
+ const int ret = threadIdx.x / (K/2);
34
+ GGML_CUDA_ASSUME(ret >= 0);
35
+ GGML_CUDA_ASSUME(ret < J);
36
+ return ret;
37
+ }
38
+
39
+ static __device__ __forceinline__ int get_k(const int l) {
40
+ const int ret = l * (K/2) + threadIdx.x % (K/2);
41
+ GGML_CUDA_ASSUME(ret >= 0);
42
+ GGML_CUDA_ASSUME(ret < K);
43
+ return ret;
44
+ }
45
+ };
46
+
47
+ struct mma_int_C_I16J8 {
48
+ static constexpr int I = 16;
49
+ static constexpr int J = 8;
50
+ static constexpr int ne = 4;
51
+
52
+ int x[ne] = {0};
53
+
54
+ static __device__ __forceinline__ int get_i(const int l) {
55
+ const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
56
+ GGML_CUDA_ASSUME(ret >= 0);
57
+ GGML_CUDA_ASSUME(ret < I);
58
+ return ret;
59
+ }
60
+
61
+ static __device__ __forceinline__ int get_j(const int l) {
62
+ const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
63
+ GGML_CUDA_ASSUME(ret >= 0);
64
+ GGML_CUDA_ASSUME(ret < J);
65
+ return ret;
66
+ }
67
+
68
+ __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
69
+ #ifdef INT8_MMA_AVAILABLE
70
+ #if __CUDA_ARCH__ >= CC_AMPERE
71
+ asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
72
+ : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
73
+ : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
74
+ #else
75
+ // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
76
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
77
+ : "+r"(x[0]), "+r"(x[1])
78
+ : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
79
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
80
+ : "+r"(x[2]), "+r"(x[3])
81
+ : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
82
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
83
+ : "+r"(x[0]), "+r"(x[1])
84
+ : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
85
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
86
+ : "+r"(x[2]), "+r"(x[3])
87
+ : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
88
+ #endif // __CUDA_ARCH__ >= CC_AMPERE
89
+ #else
90
+ GGML_UNUSED(mma_A);
91
+ GGML_UNUSED(mma_B);
92
+ NO_DEVICE_CODE;
93
+ #endif // INT8_MMA_AVAILABLE
94
+ }
95
+ };
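To make the new header concrete: each of the 32 threads of a warp owns mma_A::ne, mma_B::ne and mma_C::ne register elements of the 16x8 A tile (in 32-bit units, i.e. 16x32 int8), the 8x8 B tile and the 16x8 int32 C tile; get_i/get_j/get_k map a thread's element index l to its tile coordinates, and mma_K8 accumulates C_ij += sum over k of A_ik * B_jk. The sketch below is a hypothetical standalone kernel assuming plain row-major tiles, not part of this commit:

// Sketch only: one warp multiplies a 16x32 int8 tile by an 8x32 int8 tile into a 16x8 int32
// tile. A_tile/B_tile hold 4 packed int8 values per int, row-major with row strides of
// mma_A::K and mma_B::K ints; launch with a single warp of 32 threads.
#include "mma.cuh"

static __global__ void example_int8_mma_tile(const int * A_tile, const int * B_tile, int * C_tile) {
    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8  mma_B;
    typedef mma_int_C_I16J8 mma_C;

    mma_A A;
    mma_B B;
    mma_C C; // int32 accumulator, zero-initialized by the struct

#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        A.x[l] = A_tile[mma_A::get_i(l)*mma_A::K + mma_A::get_k(l)];
    }
#pragma unroll
    for (int l = 0; l < mma_B::ne; ++l) {
        B.x[l] = B_tile[mma_B::get_j(l)*mma_B::K + mma_B::get_k(l)];
    }

    C.mma_K8(A, B); // k runs over 8 ints = 32 int8 values per call

#pragma unroll
    for (int l = 0; l < mma_C::ne; ++l) {
        C_tile[mma_C::get_i(l)*mma_C::J + mma_C::get_j(l)] = C.x[l];
    }
}

On Ampere this maps to a single m16n8k32 instruction; on Turing mma_K8 issues four m8n8k16 instructions instead, as the inline PTX above shows.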
ggml-cuda/mmq.cuh CHANGED
@@ -2,6 +2,7 @@
2
 
3
  #include "common.cuh"
4
  #include "vecdotq.cuh"
 
5
 
6
  #include <climits>
7
  #include <cstdint>
@@ -14,6 +15,7 @@ typedef void (*load_tiles_mmq_t)(
14
  typedef void (*vec_dot_mmq_t)(
15
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
16
  const int * __restrict__ y, float * __restrict__ sum, const int & k0);
 
17
 
18
  struct block_q8_1_mmq {
19
  half2 ds[4];
@@ -141,15 +143,15 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
141
  }
142
 
143
  template <int mmq_x, int mmq_y, int nwarps>
144
- static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
145
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
146
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
147
 
148
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
149
 
150
- const float * x_dmf = (const float *) x_dm;
151
- const int * y_qs = (const int *) y + 4;
152
- const half2 * y_ds = (const half2 *) y;
153
 
154
  #pragma unroll
155
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -170,12 +172,76 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
170
  }
171
 
172
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
173
- (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
174
  y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
175
  }
176
  }
177
  }
178
 
179
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
180
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
181
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -215,7 +281,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
215
  }
216
 
217
  template <int mmq_x, int mmq_y, int nwarps>
218
- static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
219
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
220
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
221
 
@@ -249,6 +315,70 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
249
  }
250
  }
251
 
252
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
253
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
254
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -308,7 +438,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
308
  }
309
 
310
  template <int mmq_x, int mmq_y, int nwarps>
311
- static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
312
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
313
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
314
 
@@ -343,6 +473,68 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
343
  }
344
  }
345
 
346
 
347
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
348
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
@@ -400,7 +592,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
400
  }
401
 
402
  template <int mmq_x, int mmq_y, int nwarps>
403
- static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
404
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
405
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
406
 
@@ -434,6 +626,69 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
434
  }
435
  }
436
 
437
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
438
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
439
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -475,7 +730,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
475
  }
476
 
477
  template <int mmq_x, int mmq_y, int nwarps>
478
- static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
479
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
480
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
481
 
@@ -500,6 +755,69 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
500
  }
501
  }
502
 
503
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
504
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
505
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -989,6 +1307,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
989
  }
990
  }
991
 
992
  // -------------------------------------------------------------------------------------------------------------------------------------
993
 
994
  template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
@@ -998,35 +1367,65 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
998
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
999
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
1000
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
1001
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1002
  };
1003
 
1004
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1005
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
1006
  static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
1007
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
1008
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1009
  };
1010
 
1011
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1012
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
1013
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
1014
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
1015
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1016
  };
1017
 
1018
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1019
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
1020
  static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
1021
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
1022
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1023
  };
1024
 
1025
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1026
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
1027
  static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
1028
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
1029
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1030
  };
1031
 
1032
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1034,6 +1433,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
1034
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
1035
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
1036
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1037
  };
1038
 
1039
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1041,6 +1441,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
1041
  static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
1042
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
1043
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1044
  };
1045
 
1046
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1048,6 +1449,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
1048
  static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
1049
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
1050
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1051
  };
1052
 
1053
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1055,6 +1457,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
1055
  static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
1056
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
1057
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1058
  };
1059
 
1060
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1062,6 +1465,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
1062
  static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
1063
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
1064
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1065
  };
1066
 
1067
  static int mmq_need_sum(const ggml_type type_x) {
@@ -1118,6 +1522,7 @@ static __global__ void mul_mat_q(
1118
  constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
1119
  constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
1120
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
 
1121
 
1122
  constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
1123
 
@@ -1137,7 +1542,7 @@ static __global__ void mul_mat_q(
1137
 
1138
  const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
1139
 
1140
- float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
1141
 
1142
  for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
1143
 
@@ -1164,25 +1569,7 @@ static __global__ void mul_mat_q(
1164
  }
1165
  }
1166
 
1167
- #pragma unroll
1168
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1169
- const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
1170
-
1171
- if (j >= ne1) {
1172
- return;
1173
- }
1174
-
1175
- #pragma unroll
1176
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1177
- const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
1178
-
1179
- if (need_check && i >= ne0) {
1180
- continue;
1181
- }
1182
-
1183
- dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
1184
- }
1185
- }
1186
  }
1187
 
1188
  struct mmq_args {
@@ -1256,10 +1643,10 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
1256
  launch_mul_mat_q<type, 8, 4>(args, stream);
1257
  break;
1258
  case 16:
1259
- launch_mul_mat_q<type, 16, 8>(args, stream);
1260
  break;
1261
  case 24:
1262
- launch_mul_mat_q<type, 24, 8>(args, stream);
1263
  break;
1264
  case 32:
1265
  launch_mul_mat_q<type, 32, 8>(args, stream);
 
2
 
3
  #include "common.cuh"
4
  #include "vecdotq.cuh"
5
+ #include "mma.cuh"
6
 
7
  #include <climits>
8
  #include <cstdint>
 
15
  typedef void (*vec_dot_mmq_t)(
16
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
17
  const int * __restrict__ y, float * __restrict__ sum, const int & k0);
18
+ typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
19
 
20
  struct block_q8_1_mmq {
21
  half2 ds[4];
 
143
  }
144
 
145
  template <int mmq_x, int mmq_y, int nwarps>
146
+ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
147
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
148
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
149
 
150
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
151
 
152
+ const float * x_df = (const float *) x_dm;
153
+ const int * y_qs = (const int *) y + 4;
154
+ const half2 * y_ds = (const half2 *) y;
155
 
156
  #pragma unroll
157
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
 
172
  }
173
 
174
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
175
+ (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
176
  y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
177
  }
178
  }
179
  }
180
 
181
+ template <int mmq_x, int mmq_y, int nwarps>
182
+ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
183
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
184
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
185
+
186
+ GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
187
+
188
+ typedef mma_int_A_I16K8 mma_A;
189
+ typedef mma_int_B_J8K8 mma_B;
190
+ typedef mma_int_C_I16J8 mma_C;
191
+
192
+ const float * x_df = (const float *) x_dm;
193
+ const int * y_qs = (const int *) y + 4;
194
+ const half2 * y_ds = (const half2 *) y;
195
+
196
+ mma_A A;
197
+ float dA[mma_C::ne/2];
198
+
199
+ const int i0 = threadIdx.y*mma_A::I;
200
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
201
+
202
+ #pragma unroll
203
+ for (int l = 0; l < mma_A::ne; ++l) {
204
+ const int i = i0 + mma_A::get_i(l);
205
+ const int k = k0 + mma_A::get_k(l) % QI4_0;
206
+ const int shift = 4*(mma_A::get_k(l) / QI4_0);
207
+
208
+ A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
209
+ }
210
+ #pragma unroll
211
+ for (int l = 0; l < mma_C::ne/2; ++l) {
212
+ const int i = i0 + mma_C::get_i(2*l);
213
+
214
+ dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
215
+ }
216
+
217
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
218
+ mma_C C;
219
+ mma_B B;
220
+ half2 dsB[mma_C::ne/2];
221
+
222
+ #pragma unroll
223
+ for (int l = 0; l < mma_B::ne; ++l) {
224
+ const int j = j0 + mma_B::get_j(l);
225
+ const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
226
+
227
+ B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
228
+ }
229
+ #pragma unroll
230
+ for (int l = 0; l < mma_C::ne/2; ++l) {
231
+ const int j = j0 + mma_C::get_j(l);
232
+
233
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
234
+ }
235
+
236
+ C.mma_K8(A, B);
237
+
238
+ #pragma unroll
239
+ for (int l = 0; l < mma_C::ne; ++l) {
240
+ sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
241
+ }
242
+ }
243
+ }
244
+
245
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
246
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
247
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
 
281
  }
282
 
283
  template <int mmq_x, int mmq_y, int nwarps>
284
+ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
285
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
286
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
287
 
 
315
  }
316
  }
317
 
318
+ template <int mmq_x, int mmq_y, int nwarps>
319
+ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
320
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
321
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
322
+
323
+ GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
324
+
325
+ typedef mma_int_A_I16K8 mma_A;
326
+ typedef mma_int_B_J8K8 mma_B;
327
+ typedef mma_int_C_I16J8 mma_C;
328
+
329
+ const int * y_qs = (const int *) y + 4;
330
+ const half2 * y_ds = (const half2 *) y;
331
+
332
+ mma_A A;
333
+ half2 dmA[mma_C::ne/2];
334
+
335
+ const int i0 = threadIdx.y*mma_A::I;
336
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
337
+
338
+ #pragma unroll
339
+ for (int l = 0; l < mma_A::ne; ++l) {
340
+ const int i = i0 + mma_A::get_i(l);
341
+ const int k = k0 + mma_A::get_k(l) % QI4_0;
342
+ const int shift = 4*(mma_A::get_k(l) / QI4_0);
343
+
344
+ A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
345
+ }
346
+ #pragma unroll
347
+ for (int l = 0; l < mma_C::ne/2; ++l) {
348
+ const int i = i0 + mma_C::get_i(2*l);
349
+
350
+ dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
351
+ }
352
+
353
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
354
+ mma_C C;
355
+ mma_B B;
356
+ half2 dsB[mma_C::ne/2];
357
+
358
+ #pragma unroll
359
+ for (int l = 0; l < mma_B::ne; ++l) {
360
+ const int j = j0 + mma_B::get_j(l);
361
+ const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
362
+
363
+ B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
364
+ }
365
+ #pragma unroll
366
+ for (int l = 0; l < mma_C::ne/2; ++l) {
367
+ const int j = j0 + mma_C::get_j(l);
368
+
369
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
370
+ }
371
+
372
+ C.mma_K8(A, B);
373
+
374
+ #pragma unroll
375
+ for (int l = 0; l < mma_C::ne; ++l) {
376
+ const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
377
+ sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
378
+ }
379
+ }
380
+ }
381
+
382
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
383
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
384
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
 
438
  }
439
 
440
  template <int mmq_x, int mmq_y, int nwarps>
441
+ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
442
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
443
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
444
 
 
473
  }
474
  }
475
 
476
+ template <int mmq_x, int mmq_y, int nwarps>
477
+ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
478
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
479
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
480
+
481
+ GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
482
+
483
+ typedef mma_int_A_I16K8 mma_A;
484
+ typedef mma_int_B_J8K8 mma_B;
485
+ typedef mma_int_C_I16J8 mma_C;
486
+
487
+ const float * x_df = (const float *) x_dm;
488
+ const int * y_qs = (const int *) y + 4;
489
+ const float * y_df = (const float *) y;
490
+
491
+ mma_A A;
492
+ float dA[mma_C::ne/2];
493
+
494
+ const int i0 = threadIdx.y*mma_A::I;
495
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
496
+
497
+ #pragma unroll
498
+ for (int l = 0; l < mma_A::ne; ++l) {
499
+ const int i = i0 + mma_A::get_i(l);
500
+ const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
501
+
502
+ A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
503
+ }
504
+ #pragma unroll
505
+ for (int l = 0; l < mma_C::ne/2; ++l) {
506
+ const int i = i0 + mma_C::get_i(2*l);
507
+
508
+ dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
509
+ }
510
+
511
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
512
+ mma_C C;
513
+ mma_B B;
514
+ float dB[mma_C::ne/2];
515
+
516
+ #pragma unroll
517
+ for (int l = 0; l < mma_B::ne; ++l) {
518
+ const int j = j0 + mma_B::get_j(l);
519
+ const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
520
+
521
+ B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
522
+ }
523
+ #pragma unroll
524
+ for (int l = 0; l < mma_C::ne/2; ++l) {
525
+ const int j = j0 + mma_C::get_j(l);
526
+
527
+ dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
528
+ }
529
+
530
+ C.mma_K8(A, B);
531
+
532
+ #pragma unroll
533
+ for (int l = 0; l < mma_C::ne; ++l) {
534
+ sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
535
+ }
536
+ }
537
+ }
538
 
539
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
540
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
 
592
  }
593
 
594
  template <int mmq_x, int mmq_y, int nwarps>
595
+ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
596
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
597
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
598
 
 
626
  }
627
  }
628
 
629
+ template <int mmq_x, int mmq_y, int nwarps>
630
+ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
631
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
632
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
633
+
634
+ GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
635
+
636
+ typedef mma_int_A_I16K8 mma_A;
637
+ typedef mma_int_B_J8K8 mma_B;
638
+ typedef mma_int_C_I16J8 mma_C;
639
+
640
+ const int * y_qs = (const int *) y + 4;
641
+ const half2 * y_ds = (const half2 *) y;
642
+
643
+ mma_A A;
644
+ half2 dmA[mma_C::ne/2];
645
+
646
+ const int i0 = threadIdx.y*mma_A::I;
647
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
648
+
649
+ #pragma unroll
650
+ for (int l = 0; l < mma_A::ne; ++l) {
651
+ const int i = i0 + mma_A::get_i(l);
652
+ const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
653
+
654
+ A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
655
+ }
656
+ #pragma unroll
657
+ for (int l = 0; l < mma_C::ne/2; ++l) {
658
+ const int i = i0 + mma_C::get_i(2*l);
659
+
660
+ dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
661
+ }
662
+
663
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
664
+ mma_C C;
665
+ mma_B B;
666
+ half2 dsB[mma_C::ne/2];
667
+
668
+ #pragma unroll
669
+ for (int l = 0; l < mma_B::ne; ++l) {
670
+ const int j = j0 + mma_B::get_j(l);
671
+ const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
672
+
673
+ B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
674
+ }
675
+ #pragma unroll
676
+ for (int l = 0; l < mma_C::ne/2; ++l) {
677
+ const int j = j0 + mma_C::get_j(l);
678
+
679
+ dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
680
+ }
681
+
682
+ C.mma_K8(A, B);
683
+
684
+ #pragma unroll
685
+ for (int l = 0; l < mma_C::ne; ++l) {
686
+ const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
687
+ sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
688
+ }
689
+ }
690
+ }
691
+
692
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
693
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
694
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
 
730
  }
731
 
732
  template <int mmq_x, int mmq_y, int nwarps>
733
+ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
734
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
735
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
736
 
 
755
  }
756
  }
757
 
758
+ template <int mmq_x, int mmq_y, int nwarps>
759
+ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
760
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
761
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
762
+
763
+ GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
764
+
765
+ typedef mma_int_A_I16K8 mma_A;
766
+ typedef mma_int_B_J8K8 mma_B;
767
+ typedef mma_int_C_I16J8 mma_C;
768
+
769
+ const float * x_df = (const float *) x_dm;
770
+ const int * y_qs = (const int *) y + 4;
771
+ const float * y_df = (const float *) y;
772
+
773
+ mma_A A;
774
+ float dA[mma_C::ne/2];
775
+
776
+ const int i0 = threadIdx.y*mma_A::I;
777
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
778
+
779
+ #pragma unroll
780
+ for (int l = 0; l < mma_A::ne; ++l) {
781
+ const int i = i0 + mma_A::get_i(l);
782
+ const int k = k0 + mma_A::get_k(l);
783
+
784
+ A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
785
+ }
786
+ #pragma unroll
787
+ for (int l = 0; l < mma_C::ne/2; ++l) {
788
+ const int i = i0 + mma_C::get_i(2*l);
789
+
790
+ dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
791
+ }
792
+
793
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
794
+ mma_C C;
795
+ mma_B B;
796
+ float dB[mma_C::ne/2];
797
+
798
+ #pragma unroll
799
+ for (int l = 0; l < mma_B::ne; ++l) {
800
+ const int j = j0 + mma_B::get_j(l);
801
+ const int k = k0 + mma_B::get_k(l);
802
+
803
+ B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
804
+ }
805
+ #pragma unroll
806
+ for (int l = 0; l < mma_C::ne/2; ++l) {
807
+ const int j = j0 + mma_C::get_j(l);
808
+
809
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
810
+ }
811
+
812
+ C.mma_K8(A, B);
813
+
814
+ #pragma unroll
815
+ for (int l = 0; l < mma_C::ne; ++l) {
816
+ sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
817
+ }
818
+ }
819
+ }
820
+
821
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
822
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
823
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
 
1307
  }
1308
  }
1309
 
1310
+ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
1311
+ static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
1312
+ #pragma unroll
1313
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1314
+ const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
1315
+
1316
+ if (j >= ne1) {
1317
+ return;
1318
+ }
1319
+
1320
+ #pragma unroll
1321
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1322
+ const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
1323
+
1324
+ if (need_check && i >= ne0) {
1325
+ continue;
1326
+ }
1327
+
1328
+ dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
1329
+ }
1330
+ }
1331
+ }
1332
+
1333
+ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
1334
+ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
1335
+ typedef mma_int_C_I16J8 mma_C;
1336
+
1337
+ const int i0 = threadIdx.y*mma_C::I;
1338
+ static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
1339
+
1340
+ #pragma unroll
1341
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
1342
+ #pragma unroll
1343
+ for (int l = 0; l < mma_C::ne; ++l) {
1344
+ const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
1345
+
1346
+ if (j >= ne1) {
1347
+ continue;
1348
+ }
1349
+
1350
+ const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
1351
+
1352
+ if (need_check && i >= ne0) {
1353
+ continue;
1354
+ }
1355
+
1356
+ dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
1357
+ }
1358
+ }
1359
+ }
1360
+
1361
  // -------------------------------------------------------------------------------------------------------------------------------------
1362
 
1363
  template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
 
1367
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
1368
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
1369
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
1370
+ #ifdef INT8_MMA_AVAILABLE
1371
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
1372
+ static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1373
+ #else
1374
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1375
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1376
+ #endif // INT8_MMA_AVAILABLE
1377
  };
1378
 
1379
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1380
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
1381
  static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
1382
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
1383
+ #ifdef INT8_MMA_AVAILABLE
1384
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
1385
+ static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1386
+ #else
1387
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1388
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1389
+ #endif // INT8_MMA_AVAILABLE
1390
  };
1391
 
1392
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1393
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
1394
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
1395
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
1396
+ #ifdef INT8_MMA_AVAILABLE
1397
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
1398
+ static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1399
+ #else
1400
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1401
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1402
+ #endif // INT8_MMA_AVAILABLE
1403
  };
1404
 
1405
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1406
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
1407
  static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
1408
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
1409
+ #ifdef INT8_MMA_AVAILABLE
1410
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
1411
+ static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1412
+ #else
1413
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1414
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1415
+ #endif // INT8_MMA_AVAILABLE
1416
  };
1417
 
1418
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1419
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
1420
  static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
1421
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
1422
+ #ifdef INT8_MMA_AVAILABLE
1423
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
1424
+ static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1425
+ #else
1426
+ static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1427
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1428
+ #endif // INT8_MMA_AVAILABLE
1429
  };
1430
 
1431
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 
1433
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
1434
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
1435
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1436
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1437
  };
1438
 
1439
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 
1441
  static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
1442
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
1443
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1444
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1445
  };
1446
 
1447
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 
1449
  static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
1450
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
1451
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1452
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1453
  };
1454
 
1455
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 
1457
  static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
1458
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
1459
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1460
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1461
  };
1462
 
1463
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 
1465
  static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
1466
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
1467
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1468
+ static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1469
  };
1470
 
1471
  static int mmq_need_sum(const ggml_type type_x) {
 
1522
  constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
1523
  constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
1524
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
1525
+ constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;
1526
 
1527
  constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
1528
 
 
1542
 
1543
  const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
1544
 
1545
+ float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
1546
 
1547
  for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
1548
 
 
1569
  }
1570
  }
1571
 
1572
+ write_back(sum, dst, ne0, ne1);
1573
  }
1574
 
1575
  struct mmq_args {
 
1643
  launch_mul_mat_q<type, 8, 4>(args, stream);
1644
  break;
1645
  case 16:
1646
+ launch_mul_mat_q<type, 16, 4>(args, stream);
1647
  break;
1648
  case 24:
1649
+ launch_mul_mat_q<type, 24, 4>(args, stream);
1650
  break;
1651
  case 32:
1652
  launch_mul_mat_q<type, 32, 8>(args, stream);
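As a closing reference for what the accumulators above represent: for Q8_0 x Q8_1, each output element is the sum over blocks of (scale of the x block) * (scale of the y block) * (int8 dot product over the 32 values of the block), which is what sum[...] += C.x[l]*dA[l/2]*dB[l%2] accumulates in vec_dot_q8_0_q8_1_mma. A plain CPU reference for one block, as a sketch with a hypothetical helper and plain arrays rather than the block_q8_0/block_q8_1 structs:

#include <stdint.h>

// Sketch: CPU reference for one 32-value block of the q8_0 x q8_1 dot product.
static float ref_block_dot_q8_0_q8_1(const int8_t * qx, float dx, const int8_t * qy, float dy) {
    int sumi = 0;
    for (int i = 0; i < 32; ++i) {
        sumi += qx[i] * qy[i]; // the GPU does this 4 values per __dp4a, or a 16x8 tile of such sums per mma.sync
    }
    return dx * dy * (float) sumi; // scale the integer dot product by both block scales
}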