JohannesGaessler committed on
Commit fcfd59e · Parent: b9b60de

CUDA: revise q8_1 data layout for mul_mat_q (llama/7824)
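In rough terms: the q8_1 quantization of the activations (src1) used by the mul_mat_q kernels moves to a tile-friendly layout. Four consecutive q8_1 blocks are packed together, with their four half2 scale/sum pairs stored ahead of the 128 quantized values, so the kernel can load one row of the y tile with plain contiguous copies. A minimal sketch of the new layout as introduced in ggml-cuda/mmq.cuh (QK8_1 is 32 in ggml):

    // 4 q8_1 blocks packed for mul_mat_q: the 4 half2 scales/sums come first, then 4*32 quants
    struct block_q8_1_mmq {
        half2  ds[4];        // d (scale) and s (sum of quants) per 32-value sub-block
        int8_t qs[4*QK8_1];  // 128 quantized values
    };
    static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");

mul_mat_vec_q keeps using plain block_q8_1 rows from quantize_row_q8_1_cuda; only mul_mat_q switches to the new quantize_mmq_q8_1_cuda path.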

Files changed (5):
  1. ggml-cuda.cu +57 -31
  2. ggml-cuda/mmq.cu +2 -1
  3. ggml-cuda/mmq.cuh +129 -107
  4. ggml-cuda/quantize.cu +77 -10
  5. ggml-cuda/quantize.cuh +16 -1
ggml-cuda.cu CHANGED
@@ -1347,10 +1347,30 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
1347
  GGML_UNUSED(main_device);
1348
  }
1349
1350
  static void ggml_cuda_op_mul_mat(
1351
  ggml_backend_cuda_context & ctx,
1352
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
1353
- const bool convert_src1_to_q8_1) {
1354
 
1355
  const int64_t ne00 = src0->ne[0];
1356
  const int64_t ne01 = src0->ne[1];
@@ -1407,7 +1427,9 @@ static void ggml_cuda_op_mul_mat(
1407
  }
1408
 
1409
  struct dev_data {
1410
- ggml_cuda_pool_alloc<char> src0_dd_alloc;
 
 
1411
  ggml_cuda_pool_alloc<float> src1_ddf_alloc;
1412
  ggml_cuda_pool_alloc<char> src1_ddq_alloc;
1413
  ggml_cuda_pool_alloc<float> dst_dd_alloc;
@@ -1426,6 +1448,8 @@ static void ggml_cuda_op_mul_mat(
1426
  int used_devices = 0;
1427
 
1428
  for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
 
 
1429
  // by default, use all rows
1430
  dev[id].row_low = 0;
1431
  dev[id].row_high = ne01;
@@ -1476,11 +1500,15 @@ static void ggml_cuda_op_mul_mat(
1476
  dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
1477
  }
1478
 
1479
- if (convert_src1_to_q8_1) {
1480
- dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
1481
 
1482
  if (src1_on_device && src1_is_contiguous) {
1483
- quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
1484
  CUDA_CHECK(cudaGetLastError());
1485
  }
1486
  }
@@ -1526,7 +1554,12 @@ static void ggml_cuda_op_mul_mat(
1526
  const int64_t i03 = i0 / ne12;
1527
  const int64_t i02 = i0 % ne12;
1528
 
1529
- const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
1530
 
1531
  // for split tensors the data begins at i0 == i0_offset_low
1532
  char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
@@ -1543,10 +1576,17 @@ static void ggml_cuda_op_mul_mat(
1543
  // copy src0, src1 to device if necessary
1544
  if (src1_is_contiguous) {
1545
  if (id != ctx.device) {
1546
- if (convert_src1_to_q8_1) {
1547
  char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
1548
- CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, ctx.device,
1549
- src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
1550
  } else {
1551
  float * src1_ddf_i_source = (float *) src1->data;
1552
  src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
@@ -1561,8 +1601,8 @@ static void ggml_cuda_op_mul_mat(
1561
  GGML_ASSERT(false);
1562
  }
1563
 
1564
- if (convert_src1_to_q8_1 && !src1_is_contiguous) {
1565
- quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
1566
  CUDA_CHECK(cudaGetLastError());
1567
  }
1568
 
@@ -1587,22 +1627,8 @@ static void ggml_cuda_op_mul_mat(
1587
  float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
1588
  GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
1589
  dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
1590
- #if !defined(GGML_USE_HIPBLAS)
1591
- // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
1592
- cudaMemcpy3DPeerParms p = {};
1593
- p.dstDevice = ctx.device;
1594
- p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols);
1595
- p.srcDevice = id;
1596
- p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
1597
- p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
1598
- CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
1599
- #else
1600
- // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
1601
- CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
1602
- dst_dd_i, row_diff*sizeof(float),
1603
- row_diff*sizeof(float), src1_ncols,
1604
- cudaMemcpyDeviceToDevice, stream));
1605
- #endif
1606
  } else {
1607
  float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
1608
  GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
@@ -1941,13 +1967,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1941
  // KQ + KQV multi-batch
1942
  ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
1943
  } else if (use_dequantize_mul_mat_vec) {
1944
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
1945
  } else if (use_mul_mat_vec_q) {
1946
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
1947
  } else if (use_mul_mat_q) {
1948
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
1949
  } else {
1950
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
1951
  }
1952
  }
1953
 
 
1347
  GGML_UNUSED(main_device);
1348
  }
1349
 
1350
+ static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
1351
+ void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
1352
+
1353
+ #if !defined(GGML_USE_HIPBLAS)
1354
+ // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
1355
+ cudaMemcpy3DPeerParms p = {};
1356
+ p.dstDevice = dstDevice;
1357
+ p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
1358
+ p.srcDevice = srcDevice;
1359
+ p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);
1360
+ p.extent = make_cudaExtent(width, height, 1);
1361
+ return cudaMemcpy3DPeerAsync(&p, stream);
1362
+ #else
1363
+ // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
1364
+ GGML_UNUSED(dstDevice);
1365
+ GGML_UNUSED(srcDevice);
1366
+ return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
1367
+ #endif // !defined(GGML_USE_HIPBLAS)
1368
+ }
1369
+
1370
  static void ggml_cuda_op_mul_mat(
1371
  ggml_backend_cuda_context & ctx,
1372
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
1373
+ quantize_cuda_t quantize_src1) {
1374
 
1375
  const int64_t ne00 = src0->ne[0];
1376
  const int64_t ne01 = src0->ne[1];
 
1427
  }
1428
 
1429
  struct dev_data {
1430
+ int cc;
1431
+
1432
+ ggml_cuda_pool_alloc<char> src0_dd_alloc;
1433
  ggml_cuda_pool_alloc<float> src1_ddf_alloc;
1434
  ggml_cuda_pool_alloc<char> src1_ddq_alloc;
1435
  ggml_cuda_pool_alloc<float> dst_dd_alloc;
 
1448
  int used_devices = 0;
1449
 
1450
  for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
1451
+ dev[id].cc = ggml_cuda_info().devices[id].cc;
1452
+
1453
  // by default, use all rows
1454
  dev[id].row_low = 0;
1455
  dev[id].row_high = ne01;
 
1500
  dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
1501
  }
1502
 
1503
+ if (quantize_src1) {
1504
+ size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
1505
+ if (quantize_src1 == quantize_mmq_q8_1_cuda) {
1506
+ src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
1507
+ }
1508
+ dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
1509
 
1510
  if (src1_on_device && src1_is_contiguous) {
1511
+ quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
1512
  CUDA_CHECK(cudaGetLastError());
1513
  }
1514
  }
 
1554
  const int64_t i03 = i0 / ne12;
1555
  const int64_t i02 = i0 % ne12;
1556
 
1557
+ size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
1558
+ if (quantize_src1 == quantize_mmq_q8_1_cuda) {
1559
+ src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
1560
+ } else {
1561
+ src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
1562
+ }
1563
 
1564
  // for split tensors the data begins at i0 == i0_offset_low
1565
  char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
 
1576
  // copy src0, src1 to device if necessary
1577
  if (src1_is_contiguous) {
1578
  if (id != ctx.device) {
1579
+ if (quantize_src1) {
1580
  char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
1581
+ if (quantize_src1 == quantize_mmq_q8_1_cuda) {
1582
+ const size_t pitch = ne11*sizeof(block_q8_1_mmq);
1583
+ const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
1584
+ const size_t height = src1_padded_col_size/(4*QK8_1);
1585
+ CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
1586
+ } else {
1587
+ CUDA_CHECK(cudaMemcpyPeerAsync(
1588
+ src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
1589
+ }
1590
  } else {
1591
  float * src1_ddf_i_source = (float *) src1->data;
1592
  src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
 
1601
  GGML_ASSERT(false);
1602
  }
1603
 
1604
+ if (quantize_src1 && !src1_is_contiguous) {
1605
+ quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
1606
  CUDA_CHECK(cudaGetLastError());
1607
  }
1608
 
 
1627
  float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
1628
  GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
1629
  dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
1630
+ CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
1631
+ dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
1632
  } else {
1633
  float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
1634
  GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
 
1967
  // KQ + KQV multi-batch
1968
  ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
1969
  } else if (use_dequantize_mul_mat_vec) {
1970
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
1971
  } else if (use_mul_mat_vec_q) {
1972
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
1973
  } else if (use_mul_mat_q) {
1974
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
1975
  } else {
1976
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
1977
  }
1978
  }
1979
 
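Note on the ggml-cuda.cu hunks above: the bool convert_src1_to_q8_1 parameter is replaced by a quantization callback, so each path chooses how (or whether) src1 is quantized: quantize_row_q8_1_cuda for mul_mat_vec_q, quantize_mmq_q8_1_cuda for mul_mat_q, and nullptr for the cuBLAS and dequantize paths. The exact typedef is declared in ggml-cuda/quantize.cuh; a hedged sketch inferred from the call sites above:

    // Assumed shape of quantize_cuda_t, matching calls such as
    // quantize_src1(src1_ddf, src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream)
    typedef void (*quantize_cuda_t)(
        const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
        const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream);

The device-to-device copies of partial results are also factored into ggml_cuda_Memcpy2DPeerAsync, which uses cudaMemcpy3DPeerAsync on CUDA (safe with vmm pools) and falls back to cudaMemcpy2DAsync on HIP.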
ggml-cuda/mmq.cu CHANGED
@@ -11,6 +11,7 @@ void ggml_cuda_op_mul_mat_q(
11
  const int64_t nb01 = src0->nb[1];
12
 
13
  const int64_t ne10 = src1->ne[0];
 
14
  GGML_ASSERT(ne10 % QK8_1 == 0);
15
 
16
  const int64_t ne0 = dst->ne[0];
@@ -25,7 +26,7 @@ void ggml_cuda_op_mul_mat_q(
25
  // nrows_dst == nrows of the matrix that the kernel writes into
26
  const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
27
 
28
- const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst};
29
 
30
  switch (src0->type) {
31
  case GGML_TYPE_Q4_0:
 
11
  const int64_t nb01 = src0->nb[1];
12
 
13
  const int64_t ne10 = src1->ne[0];
14
+ const int64_t ne11 = src1->ne[1];
15
  GGML_ASSERT(ne10 % QK8_1 == 0);
16
 
17
  const int64_t ne0 = dst->ne[0];
 
26
  // nrows_dst == nrows of the matrix that the kernel writes into
27
  const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
28
 
29
+ const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
30
 
31
  switch (src0->type) {
32
  case GGML_TYPE_Q4_0:
ggml-cuda/mmq.cuh CHANGED
@@ -1,15 +1,26 @@
 
 
1
  #include "common.cuh"
2
  #include "vecdotq.cuh"
3
 
4
  #include <climits>
5
  #include <cstdint>
6
 
 
 
7
  typedef void (*load_tiles_mmq_t)(
8
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
9
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
10
  typedef void (*vec_dot_mmq_t)(
11
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
12
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0);
13
 
14
  struct tile_x_sizes {
15
  int ql;
@@ -132,10 +143,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
132
  template <int mmq_x, int mmq_y, int nwarps>
133
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
134
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
135
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
136
 
137
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
138
 
139
  #pragma unroll
140
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
141
  const int j = j0 + threadIdx.y;
@@ -145,19 +160,18 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
145
  const int i = i0 + threadIdx.x;
146
 
147
  const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
148
- const float * x_dmf = (const float *) x_dm;
149
 
150
  int u[2*VDR_Q4_0_Q8_1_MMQ];
151
 
152
  #pragma unroll
153
  for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
154
- u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
155
- u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
156
  }
157
 
158
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
159
- (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
160
- y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
161
  }
162
  }
163
  }
@@ -203,10 +217,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
203
  template <int mmq_x, int mmq_y, int nwarps>
204
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
205
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
206
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
207
 
208
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
209
 
 
 
 
210
  #pragma unroll
211
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
212
  const int j = j0 + threadIdx.y;
@@ -221,13 +238,13 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
221
 
222
  #pragma unroll
223
  for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
224
- u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
225
- u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
226
  }
227
 
228
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
229
- (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
230
- y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
231
  }
232
  }
233
  }
@@ -293,10 +310,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
293
  template <int mmq_x, int mmq_y, int nwarps>
294
  static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
295
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
296
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
297
 
298
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
299
 
300
  #pragma unroll
301
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
302
  const int j = j0 + threadIdx.y;
@@ -306,20 +327,18 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
306
  const int i = i0 + threadIdx.x;
307
 
308
  const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
309
- const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
310
- const float * x_dmf = (const float *) x_dm;
311
- const float * y_df = (const float *) y_ds;
312
 
313
  int u[2*VDR_Q5_0_Q8_1_MMQ];
314
 
315
  #pragma unroll
316
  for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
317
- u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
318
- u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
319
  }
320
 
321
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
322
- (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
323
  }
324
  }
325
  }
@@ -383,10 +402,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
383
  template <int mmq_x, int mmq_y, int nwarps>
384
  static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
385
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
386
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
387
 
388
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
389
 
 
 
 
390
  #pragma unroll
391
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
392
  const int j = j0 + threadIdx.y;
@@ -396,18 +418,18 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
396
  const int i = i0 + threadIdx.x;
397
 
398
  const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
399
- const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k0/QI5_1;
400
 
401
  int u[2*VDR_Q5_1_Q8_1_MMQ];
402
 
403
  #pragma unroll
404
  for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
405
- u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
406
- u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
407
  }
408
 
409
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
410
- (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
411
  }
412
  }
413
  }
@@ -455,10 +477,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
455
  template <int mmq_x, int mmq_y, int nwarps>
456
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
457
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
458
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
459
 
460
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
461
 
 
 
 
 
462
  #pragma unroll
463
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
464
  const int j = j0 + threadIdx.y;
@@ -467,12 +493,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
467
  for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
468
  const int i = i0 + threadIdx.x;
469
 
470
- const float * x_dmf = (const float *) x_dm;
471
- const float * y_df = (const float *) y_ds;
472
-
473
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
474
- (&x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[j * WARP_SIZE + k0], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
475
- y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1]);
476
  }
477
  }
478
  }
@@ -531,10 +554,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
531
  template <int mmq_x, int mmq_y, int nwarps>
532
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
533
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
534
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
535
 
536
  GGML_UNUSED(x_qh);
537
 
 
 
 
538
  #pragma unroll
539
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
540
  const int j = j0 + threadIdx.y;
@@ -545,11 +571,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
545
 
546
  const int kbx = k0 / QI2_K;
547
  const int ky = (k0 % QI2_K) * QR2_K;
548
- const float * y_df = (const float *) y_ds;
549
 
550
  int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
551
 
552
- const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
553
  const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
554
 
555
  #pragma unroll
@@ -557,11 +582,11 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
557
  v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
558
  }
559
 
560
- const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
561
 
562
- const int index_y = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE;
563
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
564
- v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
 
565
  }
566
  }
567
  }
@@ -646,7 +671,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
646
  template <int mmq_x, int mmq_y, int nwarps>
647
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
648
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
649
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
 
 
 
 
650
 
651
  #pragma unroll
652
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -658,8 +687,6 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
658
 
659
  const int kbx = k0 / QI3_K;
660
  const int ky = (k0 % QI3_K) * QR3_K;
661
- const float * x_dmf = (const float *) x_dm;
662
- const float * y_df = (const float *) y_ds;
663
 
664
  const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
665
 
@@ -667,19 +694,19 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
667
 
668
  #pragma unroll
669
  for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
670
- const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
671
  const int shift = 2 * ((ky % 32) / 8);
672
  const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
673
 
674
- const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
675
  const int vlh = (vh << 2) & 0x04040404;
676
 
677
  v[l] = __vsubss4(vll, vlh);
678
  }
679
 
680
- const int index_y = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE;
681
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
682
- v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
 
683
  }
684
  }
685
  }
@@ -746,10 +773,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
746
  template <int mmq_x, int mmq_y, int nwarps>
747
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
748
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
749
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
750
 
751
  GGML_UNUSED(x_qh);
752
 
 
 
 
753
  #pragma unroll
754
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
755
  const int j = j0 + threadIdx.y;
@@ -760,9 +790,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
760
 
761
  const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
762
 
763
- const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE;
764
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
765
- &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
 
766
  }
767
  }
768
  }
@@ -842,10 +872,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
842
  template <int mmq_x, int mmq_y, int nwarps>
843
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
844
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
845
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
846
 
847
  GGML_UNUSED(x_qh);
848
 
 
 
 
849
  #pragma unroll
850
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
851
  const int j = j0 + threadIdx.y;
@@ -856,10 +889,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
856
 
857
  const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
858
 
859
- const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k0;
860
- const int index_y = j * WARP_SIZE + (QR5_K*k0) % WARP_SIZE;
861
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
862
- &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
 
863
  }
864
  }
865
  }
@@ -932,10 +964,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
932
  template <int mmq_x, int mmq_y, int nwarps>
933
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
934
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
935
- const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
936
 
937
  GGML_UNUSED(x_qh);
938
 
 
 
 
 
939
  #pragma unroll
940
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
941
  const int j = j0 + threadIdx.y;
@@ -944,15 +980,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
944
  for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
945
  const int i = i0 + threadIdx.x;
946
 
947
- const float * x_dmf = (const float *) x_dm;
948
- const float * y_df = (const float *) y_ds;
949
-
950
  const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
951
 
952
- const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k0;
953
- const int index_y = j * WARP_SIZE + (QR6_K*k0) % WARP_SIZE;
954
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
955
- &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 
956
  }
957
  }
958
  }
@@ -964,7 +996,6 @@ struct mmq_type_traits;
964
 
965
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
966
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
967
- static constexpr bool need_sum = true;
968
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
969
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
970
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -972,7 +1003,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
972
 
973
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
974
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
975
- static constexpr bool need_sum = true;
976
  static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
977
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
978
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -980,7 +1010,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
980
 
981
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
982
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
983
- static constexpr bool need_sum = false;
984
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
985
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
986
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -988,7 +1017,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
988
 
989
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
990
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
991
- static constexpr bool need_sum = true;
992
  static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
993
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
994
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -996,7 +1024,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
996
 
997
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
998
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
999
- static constexpr bool need_sum = false;
1000
  static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
1001
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
1002
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -1004,7 +1031,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
1004
 
1005
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1006
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
1007
- static constexpr bool need_sum = false;
1008
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
1009
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
1010
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -1012,7 +1038,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
1012
 
1013
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1014
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
1015
- static constexpr bool need_sum = false;
1016
  static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
1017
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
1018
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -1020,7 +1045,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
1020
 
1021
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1022
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
1023
- static constexpr bool need_sum = true;
1024
  static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
1025
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
1026
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -1028,7 +1052,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
1028
 
1029
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1030
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
1031
- static constexpr bool need_sum = true;
1032
  static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
1033
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
1034
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -1036,12 +1059,36 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
1036
 
1037
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1038
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
1039
- static constexpr bool need_sum = false;
1040
  static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
1041
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
1042
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1043
  };
1044
 
1045
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
1046
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
1047
  #if defined(RDNA3) || defined(RDNA2)
@@ -1056,7 +1103,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
1056
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
1057
  static __global__ void mul_mat_q(
1058
  const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
1059
- const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) {
1060
 
1061
  // Skip unused template specializations for faster compilation:
1062
  if (mmq_x > get_mmq_x_max_device()) {
@@ -1068,7 +1115,6 @@ static __global__ void mul_mat_q(
1068
  constexpr int qr = ggml_cuda_type_traits<type>::qr;
1069
  constexpr int qi = ggml_cuda_type_traits<type>::qi;
1070
  constexpr int mmq_y = get_mmq_y_device(mmq_x);
1071
- constexpr bool need_sum = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::need_sum;
1072
  constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
1073
  constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
1074
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
@@ -1080,62 +1126,38 @@ static __global__ void mul_mat_q(
1080
  half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
1081
  int * tile_x_qh = (int *) (tile_x_dm + txs.dm);
1082
  int * tile_x_sc = (int *) (tile_x_qh + txs.qh);
1083
- int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE]
1084
- half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1];
1085
-
1086
- const block_q8_1 * y = (const block_q8_1 *) yc;
1087
 
1088
  const int blocks_per_row_x = ne00 / qk;
1089
- const int blocks_per_col_y = ne10 / QK8_1;
1090
  const int blocks_per_warp = WARP_SIZE / qi;
1091
 
1092
  const int & ne1 = ne11;
1093
 
1094
  const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
1095
 
 
 
1096
  float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
1097
 
1098
  for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
1099
 
1100
- load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00);
1101
 
1102
  #pragma unroll
1103
  for (int kr = 0; kr < qr; ++kr) {
1104
- const int kqs = kr*WARP_SIZE + threadIdx.x;
1105
- const int kbxd = kqs / QI8_1;
1106
-
1107
  #pragma unroll
1108
- for (int i0 = 0; i0 < mmq_x; i0 += nwarps) {
1109
- const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses
1110
-
1111
- const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd];
1112
 
1113
- const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE;
1114
- tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
1115
- }
1116
-
1117
- #pragma unroll
1118
- for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
1119
- const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
1120
- const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
1121
- const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1);
1122
-
1123
- // if the sum is not needed it's faster to transform the scale to f32 ahead of time
1124
- const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds;
1125
- half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
1126
- if (need_sum) {
1127
- *dsi_dst = *dsi_src;
1128
- } else {
1129
- float * dfi_dst = (float *) dsi_dst;
1130
- *dfi_dst = __low2float(*dsi_src);
1131
- }
1132
  }
1133
 
1134
  __syncthreads();
1135
 
1136
  // #pragma unroll // unrolling this loop causes too much register pressure
1137
  for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
1138
- vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0);
1139
  }
1140
 
1141
  __syncthreads();
@@ -1165,8 +1187,8 @@ static __global__ void mul_mat_q(
1165
 
1166
  struct mmq_args {
1167
  const char * x; const char * y; float * dst;
1168
- int64_t ne00; int64_t ne01; int64_t stride00;
1169
- int64_t ne10; int64_t ne11;
1170
  int64_t ne0;
1171
  };
1172
 
@@ -1184,7 +1206,7 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
1184
  const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
1185
  const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
1186
  const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
1187
- const int shmem = shmem_x + shmem_y;
1188
 
1189
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
1190
  static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1198,11 +1220,11 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
1198
  if (args.ne01 % mmq_y == 0) {
1199
  const bool need_check = false;
1200
  mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
1201
- (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
1202
  } else {
1203
  const bool need_check = true;
1204
  mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
1205
- (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
1206
  }
1207
  }
1208
 
 
1
+ #pragma once
2
+
3
  #include "common.cuh"
4
  #include "vecdotq.cuh"
5
 
6
  #include <climits>
7
  #include <cstdint>
8
 
9
+ #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
10
+
11
  typedef void (*load_tiles_mmq_t)(
12
  const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
13
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
14
  typedef void (*vec_dot_mmq_t)(
15
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
16
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0);
17
+
18
+ struct block_q8_1_mmq {
19
+ half2 ds[4];
20
+ int8_t qs[4*QK8_1];
21
+ };
22
+ static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
23
+ static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
24
 
25
  struct tile_x_sizes {
26
  int ql;
 
143
  template <int mmq_x, int mmq_y, int nwarps>
144
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
145
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
146
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
147
 
148
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
149
 
150
+ const float * x_dmf = (const float *) x_dm;
151
+ const int * y_qs = (const int *) y + 4;
152
+ const half2 * y_ds = (const half2 *) y;
153
+
154
  #pragma unroll
155
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
156
  const int j = j0 + threadIdx.y;
 
160
  const int i = i0 + threadIdx.x;
161
 
162
  const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
 
163
 
164
  int u[2*VDR_Q4_0_Q8_1_MMQ];
165
 
166
  #pragma unroll
167
  for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
168
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
169
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
170
  }
171
 
172
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
173
+ (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
174
+ y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
175
  }
176
  }
177
  }
 
217
  template <int mmq_x, int mmq_y, int nwarps>
218
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
219
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
220
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
221
 
222
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
223
 
224
+ const int * y_qs = (const int *) y + 4;
225
+ const half2 * y_ds = (const half2 *) y;
226
+
227
  #pragma unroll
228
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
229
  const int j = j0 + threadIdx.y;
 
238
 
239
  #pragma unroll
240
  for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
241
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
242
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];
243
  }
244
 
245
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
246
+ (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
247
+ y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
248
  }
249
  }
250
  }
 
310
  template <int mmq_x, int mmq_y, int nwarps>
311
  static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
312
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
313
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
314
 
315
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
316
 
317
+ const float * x_dmf = (const float *) x_dm;
318
+ const int * y_qs = (const int *) y + 4;
319
+ const float * y_df = (const float *) y;
320
+
321
  #pragma unroll
322
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
323
  const int j = j0 + threadIdx.y;
 
327
  const int i = i0 + threadIdx.x;
328
 
329
  const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
330
+ const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
 
 
331
 
332
  int u[2*VDR_Q5_0_Q8_1_MMQ];
333
 
334
  #pragma unroll
335
  for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
336
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
337
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE];
338
  }
339
 
340
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
341
+ (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
342
  }
343
  }
344
  }
 
402
  template <int mmq_x, int mmq_y, int nwarps>
403
  static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
404
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
405
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
406
 
407
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
408
 
409
+ const int * y_qs = (const int *) y + 4;
410
+ const half2 * y_ds = (const half2 *) y;
411
+
412
  #pragma unroll
413
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
414
  const int j = j0 + threadIdx.y;
 
418
  const int i = i0 + threadIdx.x;
419
 
420
  const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
421
+ const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;
422
 
423
  int u[2*VDR_Q5_1_Q8_1_MMQ];
424
 
425
  #pragma unroll
426
  for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
427
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
428
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE];
429
  }
430
 
431
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
432
+ (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
433
  }
434
  }
435
  }
 
477
  template <int mmq_x, int mmq_y, int nwarps>
478
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
479
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
480
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
481
 
482
  GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
483
 
484
+ const float * x_dmf = (const float *) x_dm;
485
+ const int * y_qs = (const int *) y + 4;
486
+ const float * y_df = (const float *) y;
487
+
488
  #pragma unroll
489
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
490
  const int j = j0 + threadIdx.y;
 
493
  for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
494
  const int i = i0 + threadIdx.x;
495
 
 
 
 
496
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
497
+ (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
498
+ y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
499
  }
500
  }
501
  }
 
554
  template <int mmq_x, int mmq_y, int nwarps>
555
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
556
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
557
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
558
 
559
  GGML_UNUSED(x_qh);
560
 
561
+ const int * y_qs = (const int *) y + 4;
562
+ const float * y_df = (const float *) y;
563
+
564
  #pragma unroll
565
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
566
  const int j = j0 + threadIdx.y;
 
571
 
572
  const int kbx = k0 / QI2_K;
573
  const int ky = (k0 % QI2_K) * QR2_K;
 
574
 
575
  int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
576
 
577
+ const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
578
  const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
579
 
580
  #pragma unroll
 
582
  v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
583
  }
584
 
585
+ const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
586
 
 
587
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
588
+ v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales,
589
+ x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
590
  }
591
  }
592
  }
 
671
  template <int mmq_x, int mmq_y, int nwarps>
672
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
673
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
674
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
675
+
676
+ const float * x_dmf = (const float *) x_dm;
677
+ const int * y_qs = (const int *) y + 4;
678
+ const float * y_df = (const float *) y;
679
 
680
  #pragma unroll
681
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
 
687
 
688
  const int kbx = k0 / QI3_K;
689
  const int ky = (k0 % QI3_K) * QR3_K;
 
 
690
 
691
  const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
692
 
 
694
 
695
  #pragma unroll
696
  for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
697
+ const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
698
  const int shift = 2 * ((ky % 32) / 8);
699
  const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
700
 
701
+ const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
702
  const int vlh = (vh << 2) & 0x04040404;
703
 
704
  v[l] = __vsubss4(vll, vlh);
705
  }
706
 
 
707
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
708
+ v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
709
+ x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
710
  }
711
  }
712
  }
 
773
  template <int mmq_x, int mmq_y, int nwarps>
774
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
775
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
776
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
777
 
778
  GGML_UNUSED(x_qh);
779
 
780
+ const int * y_qs = (const int *) y + 4;
781
+ const half2 * y_ds = (const half2 *) y;
782
+
783
  #pragma unroll
784
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
785
  const int j = j0 + threadIdx.y;
 
790
 
791
  const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
792
 
 
793
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
794
+ &x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
795
+ x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
796
  }
797
  }
798
  }
 
872
  template <int mmq_x, int mmq_y, int nwarps>
873
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
874
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
875
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
876
 
877
  GGML_UNUSED(x_qh);
878
 
879
+ const int * y_qs = (const int *) y + 4;
880
+ const half2 * y_ds = (const half2 *) y;
881
+
882
  #pragma unroll
883
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
884
  const int j = j0 + threadIdx.y;
 
889
 
890
  const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
891
 
 
 
892
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
893
+ &x_ql[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
894
+ x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
895
  }
896
  }
897
  }
 
964
  template <int mmq_x, int mmq_y, int nwarps>
965
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
966
  const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
967
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
968
 
969
  GGML_UNUSED(x_qh);
970
 
971
+ const float * x_dmf = (const float *) x_dm;
972
+ const int * y_qs = (const int *) y + 4;
973
+ const float * y_df = (const float *) y;
974
+
975
  #pragma unroll
976
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
977
  const int j = j0 + threadIdx.y;
 
980
  for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
981
  const int i = i0 + threadIdx.x;
982
 
 
 
 
983
  const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
984
 
 
 
985
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
986
+ &x_ql[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
987
+ x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
988
  }
989
  }
990
  }
 
996
 
997
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
998
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
 
999
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
1000
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
1001
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1003
 
1004
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1005
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
 
1006
  static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
1007
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
1008
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1010
 
1011
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1012
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
 
1013
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
1014
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
1015
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1017
 
1018
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1019
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
 
1020
  static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
1021
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
1022
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1024
 
1025
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1026
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
 
1027
  static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
1028
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
1029
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1031
 
1032
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1033
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
 
1034
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
1035
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
1036
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1038
 
1039
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1040
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
 
1041
  static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
1042
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
1043
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1045
 
1046
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1047
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
 
1048
  static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
1049
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
1050
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1052
 
1053
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1054
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
 
1055
  static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
1056
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
1057
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 
1059
 
1060
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1061
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
 
1062
  static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
1063
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
1064
  static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1065
  };
1066
 
1067
+ static int mmq_need_sum(const ggml_type type_x) {
1068
+ switch (type_x) {
1069
+ case GGML_TYPE_Q4_0:
1070
+ case GGML_TYPE_Q4_1:
1071
+ return true;
1072
+ case GGML_TYPE_Q5_0:
1073
+ return false;
1074
+ case GGML_TYPE_Q5_1:
1075
+ return true;
1076
+ case GGML_TYPE_Q8_0:
1077
+ case GGML_TYPE_Q2_K:
1078
+ case GGML_TYPE_Q3_K:
1079
+ return false;
1080
+ case GGML_TYPE_Q4_K:
1081
+ case GGML_TYPE_Q5_K:
1082
+ return true;
1083
+ case GGML_TYPE_Q6_K:
1084
+ return false;
1085
+ default:
1086
+ GGML_ASSERT(false);
1087
+ break;
1088
+ }
1089
+ return false;
1090
+ }
1091
+
1092
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
1093
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
1094
  #if defined(RDNA3) || defined(RDNA2)
 
1103
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
1104
  static __global__ void mul_mat_q(
1105
  const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
1106
+ const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
1107
 
1108
  // Skip unused template specializations for faster compilation:
1109
  if (mmq_x > get_mmq_x_max_device()) {
 
1115
  constexpr int qr = ggml_cuda_type_traits<type>::qr;
1116
  constexpr int qi = ggml_cuda_type_traits<type>::qi;
1117
  constexpr int mmq_y = get_mmq_y_device(mmq_x);
 
1118
  constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
1119
  constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
1120
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
 
1126
  half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
1127
  int * tile_x_qh = (int *) (tile_x_dm + txs.dm);
1128
  int * tile_x_sc = (int *) (tile_x_qh + txs.qh);
1129
+ int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
 
 
 
1130
 
1131
  const int blocks_per_row_x = ne00 / qk;
  const int blocks_per_warp = WARP_SIZE / qi;
 
  const int & ne1 = ne11;
 
  const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
 
+ const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+
  float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
 
  for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
 
+ load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
 
  #pragma unroll
  for (int kr = 0; kr < qr; ++kr) {
+ const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
  #pragma unroll
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+ int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
 
+ tile_y[l] = by0[l];
  }
 
  __syncthreads();
 
  // #pragma unroll // unrolling this loop causes too much register pressure
  for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
+ vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y, sum, k0);
  }
 
  __syncthreads();
 
  struct mmq_args {
  const char * x; const char * y; float * dst;
+ int64_t ne00; int64_t ne01; int64_t stride01;
+ int64_t ne10; int64_t ne11; int64_t stride11;
  int64_t ne0;
  };
 
  const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
  const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
  const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
+ const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
 
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
  static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
 
  if (args.ne01 % mmq_y == 0) {
  const bool need_check = false;
  mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
+ (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
  } else {
  const bool need_check = true;
  mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
+ (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
  }
  }
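
Note on the y layout consumed by the revised mul_mat_q kernel above: the tile and stride arithmetic is expressed in units of block_q8_1_mmq and MMQ_TILE_Y_K, both defined elsewhere in mmq.cuh. A minimal sketch of the layout those sizes imply, derived only from the arithmetic in this hunk (the field order is an assumption, not taken from the diff):

    struct block_q8_1_mmq {      // sketch: four consecutive q8_1 blocks of one row packed together
        half2  ds[4];            // one (d, sum) pair per QK8_1 quants
        int8_t qs[4*QK8_1];      // 128 quantized values
    };

Per row of the y tile this is 128 bytes of quants plus 16 bytes of scales, i.e. 36 ints = WARP_SIZE + WARP_SIZE/QI8_1, which matches the comment on tile_y and the shmem_y term used for the launch below.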
 
ggml-cuda/quantize.cu CHANGED
@@ -1,22 +1,23 @@
  #include "quantize.cuh"
+ #include <cstdint>
 
- static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) {
- const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
+ const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
- if (ix >= kx_padded) {
+ if (ix0 >= kx0_padded) {
  return;
  }
 
- const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y;
+ const int64_t ix1 = blockIdx.y;
 
- const int64_t i_padded = (int64_t)iy*kx_padded + ix;
+ const int64_t i_padded = ix1*kx0_padded + ix0;
 
  block_q8_1 * y = (block_q8_1 *) vy;
 
  const int64_t ib = i_padded / QK8_1; // block index
  const int64_t iqs = i_padded % QK8_1; // quant index
 
- const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
+ const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
  float amax = fabsf(xi);
  float sum = xi;
 
@@ -36,10 +37,76 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
  reinterpret_cast<half&>(y[ib].ds.y) = sum;
  }
 
- void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) {
- const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
- const dim3 num_blocks(block_num_x, ky, 1);
+ template <bool need_sum>
+ static __global__ void quantize_mmq_q8_1(
+ const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+
+ const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (ix0 >= kx0_padded) {
+ return;
+ }
+
+ const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+
+ block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+ const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
+ const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
+ const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
+
+ const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
+ float amax = fabsf(xi);
+
+ amax = warp_reduce_max(amax);
+
+ float sum;
+ if (need_sum) {
+ sum = warp_reduce_sum(xi);
+ }
+
+ const float d = amax / 127;
+ const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+ y[ib].qs[iqs] = q;
+
+ if (iqs % QK8_1 != 0) {
+ return;
+ }
+
+ if (need_sum) {
+ y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
+ } else {
+ ((float *) y[ib].ds)[iqs/QK8_1] = d;
+ }
+ }
+
+ void quantize_row_q8_1_cuda(
+ const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+ const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+ GGML_ASSERT(kx0_padded % QK8_1 == 0);
+
+ const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+ const dim3 num_blocks(block_num_x, kx1*channels, 1);
  const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
- quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+ quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
+
+ GGML_UNUSED(type_x);
  }
 
+ void quantize_mmq_q8_1_cuda(
+ const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+ const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+ GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+
+ const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+ const dim3 num_blocks(block_num_x, kx1, channels);
+ const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+ if (mmq_need_sum(type_x)) {
+ quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+ } else {
+ quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+ }
+ }
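
The new quantize_mmq_q8_1 kernel writes one block_q8_1_mmq per 4*QK8_1 padded input values and asserts that kx0_padded is a multiple of 4*QK8_1. A hypothetical host-side call illustrating the buffer sizing this implies (the pointer and size names here are invented for the example, not taken from this commit):

    // pad the column count to a multiple of 4*QK8_1, as quantize_mmq_q8_1_cuda requires
    const int64_t ncols_padded = GGML_PAD(ncols, 4*QK8_1);
    // one block_q8_1_mmq per 4*QK8_1 padded values, per row, per channel
    const size_t nbytes = channels*nrows*ncols_padded*sizeof(block_q8_1_mmq)/(4*QK8_1);
    void * y_q8 = nullptr;
    CUDA_CHECK(cudaMalloc(&y_q8, nbytes));
    quantize_mmq_q8_1_cuda(x_f32, y_q8, ncols, nrows, channels, ncols_padded, GGML_TYPE_Q4_0, stream);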
ggml-cuda/quantize.cuh CHANGED
@@ -1,5 +1,20 @@
+ #pragma once
+
  #include "common.cuh"
+ #include "mmq.cuh"
+
+ #include <cstdint>
 
  #define CUDA_QUANTIZE_BLOCK_SIZE 256
 
- void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream);
+ typedef void (*quantize_cuda_t)(
+ const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+ const ggml_type type_x, cudaStream_t stream);
+
+ void quantize_row_q8_1_cuda(
+ const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+ const ggml_type type_x, cudaStream_t stream);
+
+ void quantize_mmq_q8_1_cuda(
+ const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+ const ggml_type type_x, cudaStream_t stream);
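
Both entry points now share the quantize_cuda_t signature, so a caller can switch between the generic q8_1 layout and the MMQ-specific layout with a plain function pointer. A sketch of that pattern (hypothetical caller, variable names invented for the example):

    const bool use_mmq = true; // e.g. when the matrix multiplication will go through mul_mat_q
    quantize_cuda_t quantize_fn = use_mmq ? quantize_mmq_q8_1_cuda : quantize_row_q8_1_cuda;
    quantize_fn(x_f32, y_q8, ncols, nrows, /*channels=*/1, ncols_padded, type_x, stream);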