Nicolò Scipione committed
Commit 56f0e48 · 1 Parent(s): 747ad97

sycl: Add reorder to Q6_K mmvq implementation (llama/13885)

* Add Reorder to Q6_K mmvq implementation

* Address PR comments: clean up comments

* Remove unused parameter after refactoring q4_k

* Add inline to function and remove unnecessary reference to int

---------

Signed-off-by: nscipione <[email protected]>
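
For orientation: the reorder introduced below packs each Q6_K tensor into four contiguous regions (all low 4-bit quants, then all high 2-bit quants, then all scales, then all per-block d values) instead of interleaved block_q6_K structs. A minimal sketch of the byte offsets this implies, assuming QK_K = 256 (illustrative helper, not part of the commit):

// Illustrative helper, not part of the commit: byte offsets of the reordered
// Q6_K layout used by this patch, assuming QK_K = 256.
#include <cstdio>

int main() {
    const long QK_K    = 256;
    const long nblocks = 4;  // hypothetical number of Q6_K blocks in the tensor
    const long ib      = 2;  // hypothetical block index

    const long ql     = ib * (QK_K / 2);                                     // low 4 bits, 128 B per block
    const long qh     = nblocks * (QK_K / 2) + ib * (QK_K / 4);              // high 2 bits, 64 B per block
    const long scales = nblocks * (QK_K / 2 + QK_K / 4) + ib * (QK_K / 16);  // int8 scales, 16 B per block
    const long d      = nblocks * (QK_K / 2 + QK_K / 4 + QK_K / 16);         // fp16 d region starts here
    printf("ql=%ld qh=%ld scales=%ld d_region=%ld\n", ql, qh, scales, d);
    return 0;
}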

ggml/src/ggml-sycl/convert.cpp CHANGED
@@ -265,6 +265,17 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
  #endif
  }

+ template <typename dst_t>
+ static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+     const int64_t nb = k / QK_K;
+
+     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+     stream->parallel_for(
+         sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
+         [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
+ }
+
  template <typename dst_t>
  static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
@@ -530,7 +541,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
          case GGML_TYPE_Q5_K:
              return dequantize_row_q5_K_sycl;
          case GGML_TYPE_Q6_K:
-             return dequantize_row_q6_K_sycl;
+             if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                 return dequantize_row_q6_K_sycl_reorder;
+             } else {
+                 return dequantize_row_q6_K_sycl;
+             }
          case GGML_TYPE_IQ1_S:
              return dequantize_row_iq1_s_sycl;
          case GGML_TYPE_IQ1_M:
@@ -587,7 +602,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
          case GGML_TYPE_Q5_K:
              return dequantize_row_q5_K_sycl;
          case GGML_TYPE_Q6_K:
-             return dequantize_row_q6_K_sycl;
+             if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                 return dequantize_row_q6_K_sycl_reorder;
+             } else {
+                 return dequantize_row_q6_K_sycl;
+             }
          case GGML_TYPE_IQ1_S:
              return dequantize_row_iq1_s_sycl;
          case GGML_TYPE_IQ1_M:
ggml/src/ggml-sycl/dequantize.hpp CHANGED
@@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
  #endif
  }

+ template <typename dst_t>
+ static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                           const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
+     const int64_t ib = item_ct1.get_group(2);
+
+     const int64_t tid = item_ct1.get_local_id(2);
+     const int64_t ip  = tid / 32;       // ip is 0 or 1
+     const int64_t il  = tid - 32 * ip;  // 0...32
+     const int64_t is  = 8 * ip + il / 16;
+
+     const uint8_t * base_ptr           = static_cast<const uint8_t *>(vx);
+     const auto      ql_offset          = ib * (QK_K / 2);
+     const auto      qh_offset          = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;
+     const auto      base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;
+     const auto      base_d_offset      = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;
+     const uint8_t * ql_ptr             = base_ptr + ql_offset;
+     const uint8_t * qh_ptr             = base_ptr + qh_offset;
+     const uint8_t * scales_ptr         = base_ptr + base_scales_offset;
+     const ggml_half * d                = (const ggml_half *) (base_ptr + base_d_offset) + ib;
+
+     dst_t * y = yy + ib * QK_K + 128 * ip + il;
+
+     const uint8_t * ql = ql_ptr + 64 * ip + il;
+     const uint8_t   qh = *(qh_ptr + 32 * ip + il);
+     const int8_t *  sc = reinterpret_cast<const int8_t *>(scales_ptr + is);
+
+     y[0]  = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+     y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+     y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+     y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+ }
+
  template<typename dst_t>
  static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                       const sycl::nd_item<3> &item_ct1,
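
A quick illustration of the thread mapping in dequantize_block_q6_K_reorder above: each block is processed by 64 work-items, and every work-item writes four outputs at strides of 32, so one work-group covers exactly QK_K = 256 dequantized values. A host-side sketch of that index math (illustrative only, not part of the commit):

// Illustrative check, not part of the commit: the ip/il mapping used by
// dequantize_block_q6_K_reorder covers output indices 0..255 exactly once per block.
#include <cassert>
#include <initializer_list>

int main() {
    bool written[256] = { false };
    for (int tid = 0; tid < 64; ++tid) {
        const int ip = tid / 32;       // 0 or 1, selects the block half
        const int il = tid - 32 * ip;  // 0..31 within the half
        const int y0 = 128 * ip + il;  // base output index, as in the kernel
        for (int off : { 0, 32, 64, 96 }) {
            assert(!written[y0 + off]);
            written[y0 + off] = true;
        }
    }
    for (bool w : written) {
        assert(w);  // every one of the QK_K = 256 outputs is produced once
    }
    return 0;
}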
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -354,7 +354,8 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
          assert(tensor->view_src->buffer->buft == buffer->buft);
          return GGML_STATUS_SUCCESS;
      }
-     if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
+     if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
+         !g_ggml_sycl_disable_optimize) {
          ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
          tensor->extra = extra;
          ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
@@ -2989,6 +2990,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
          case GGML_TYPE_Q4_0:
              return true;
          case GGML_TYPE_Q4_K:
+         case GGML_TYPE_Q6_K:
              return !g_ggml_sycl_prioritize_dmmv;
          default:
              return false;
@@ -3008,6 +3010,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
      switch (type) {
          case GGML_TYPE_Q4_0:
          case GGML_TYPE_Q4_K:
+         case GGML_TYPE_Q6_K:
              return true;
          default:
              return false;
@@ -3092,6 +3095,50 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
      sycl::free(tmp_buf, *stream);
  }

+ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+     GGML_ASSERT(size % sizeof(block_q6_K) == 0);
+     GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
+
+     const int nblocks = size / sizeof(block_q6_K);
+
+     auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+     SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+     auto * ql_ptr = data_device;
+     auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
+     auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
+     sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
+
+     stream
+         ->parallel_for(nblocks,
+                        [=](auto i) {
+                            const block_q6_K * x = (const block_q6_K *) tmp_buf;
+                            const int ib = i;
+
+                            const uint8_t * ql = x[ib].ql;
+                            const uint8_t * qh = x[ib].qh;
+                            uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
+                            uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
+                            uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
+
+                            for (int j = 0; j < QK_K / 2; ++j) {
+                                base_ql_ptr[j] = ql[j];
+                            }
+                            for (int j = 0; j < QK_K / 4; ++j) {
+                                base_qh_ptr[j] = qh[j];
+                            }
+
+                            for (int j = 0; j < QK_K / 16; ++j) {
+                                base_scales_ptr[j] = x[ib].scales[j];
+                            }
+
+                            dm_ptr[ib] = x[ib].d;
+                        })
+         .wait_and_throw();
+
+     sycl::free(tmp_buf, *stream);
+ }
+
  static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
      uint8_t * data_device = (uint8_t *) src0->data;
      size_t ncols = src0->ne[0];
@@ -3105,6 +3152,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
          case GGML_TYPE_Q4_K:
              reorder_qw_q4_k(data_device, size, 0, stream);
              break;
+         case GGML_TYPE_Q6_K:
+             reorder_qw_q6_k(data_device, size, 0, stream);
+             break;
          default:
              GGML_ABORT("reorder_qw() called with unsupported type");
              break;
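
One detail worth noting about reorder_qw_q6_k above: the reordered regions fit exactly in the original allocation, which is why the data can be rewritten in place after the staging copy. A back-of-the-envelope check, assuming QK_K = 256 (illustrative only, not part of the commit):

// Illustrative check, not part of the commit: per Q6_K block the reordered data
// needs QK_K/2 (ql) + QK_K/4 (qh) + QK_K/16 (scales) bytes plus one fp16 d value,
// i.e. 128 + 64 + 16 + 2 = 210 bytes, which should match sizeof(block_q6_K).
constexpr int reordered_bytes_per_block = 256 / 2 + 256 / 4 + 256 / 16 + 2;
static_assert(reordered_bytes_per_block == 210, "reordered Q6_K payload matches the block_q6_K footprint");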
ggml/src/ggml-sycl/mmvq.cpp CHANGED
@@ -31,11 +31,10 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r

      float partial_sum = 0.0f;
      for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
-         const int ibx = row * blocks_per_row + i; // x block index
-         // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
-         const int bx_offset = block_type::get_block_offset(ibx);
-         const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
+         const int ibx = row * blocks_per_row + i; // x block index

+         const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
+         const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
          // Y block index that aligns with ibx
          const int iby = i * block_type::block_to_q8_1_ratio();
          const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
@@ -46,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
              // x block quant index when casting the quants to int
              const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);

-             partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks);
+             partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
          }
      }

@@ -785,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
      }
  }

+ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
+                                                const int nrows, dpct::queue_ptr stream) {
+     GGML_ASSERT(ncols % QK_K == 0);
+     const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
+     constexpr size_t num_subgroups = 16;
+     GGML_ASSERT(block_num_y % num_subgroups == 0);
+
+     const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
+     const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
+
+     stream->submit([&](sycl::handler & cgh) {
+         cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
+                          [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                              mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
+                                                                                            nd_item);
+                          });
+     });
+ }
  static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
                                         float *dst, const int ncols,
                                         const int nrows,
@@ -1070,7 +1087,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
              mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
              break;
          case GGML_TYPE_Q6_K:
-             mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+             if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                 ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                 GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
+                 reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+             } else {
+                 GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
+                 mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+             }
              break;
          case GGML_TYPE_IQ1_S:
              mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
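
To make the launch shape of reorder_mul_mat_vec_q6_k_q8_1_sycl above concrete: the grid provides WARP_SIZE work-items per row and groups 16 sub-groups into each work-group. A sketch with a hypothetical row count, assuming the backend defaults WARP_SIZE = 32 and GGML_SYCL_MMV_Y = 1 (illustrative only, not part of the commit):

// Illustrative only, not part of the commit: launch geometry of
// reorder_mul_mat_vec_q6_k_q8_1_sycl for a hypothetical nrows = 4096,
// assuming WARP_SIZE = 32 and GGML_SYCL_MMV_Y = 1.
#include <cstdio>

int main() {
    const int WARP_SIZE       = 32;    // assumed sub-group size
    const int GGML_SYCL_MMV_Y = 1;     // assumed y extent per work-group
    const int num_subgroups   = 16;    // fixed in the new kernel launcher
    const int nrows           = 4096;  // hypothetical

    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;  // ceil_div
    printf("global nd-range z = %d\n", block_num_y * WARP_SIZE);              // 131072 work-items
    printf("local  nd-range z = %d\n", num_subgroups * WARP_SIZE);            // 512 work-items
    printf("work-groups       = %d\n", block_num_y / num_subgroups);          // 256 (needs block_num_y % 16 == 0)
    return 0;
}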
ggml/src/ggml-sycl/quants.hpp CHANGED
@@ -14,12 +14,13 @@
  #ifndef GGML_SYCL_QUANTS_HPP
  #define GGML_SYCL_QUANTS_HPP

+ #include <utility>
+
  #include "ggml-common.h"
  #include "ggml.h"

  namespace ggml_sycl_reordered {

-
  // The reordered block moves quants (qs) and scales(d) to two
  // uniform regions of memory that is contiguous in the same tensor.
  // What this means is that instead of having:
@@ -32,7 +33,6 @@ namespace ggml_sycl_reordered {

  template <ggml_type type> struct block_q_t;

-
  // qk number of weights / quants in a block
  // qr number of weights in a byte (described as 'before dequantization')
  // for quantization types that has low and high bits split, qr is calculated with
@@ -47,10 +47,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
          static constexpr uint32_t vdr_mmvq = 2;
      };

-     static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+     static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
+         return { block_index * (traits::qk / traits::qr), 0 };
+     }

-     static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
-         return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
+     static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
+         return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 };
      }

      static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
@@ -64,20 +66,46 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
          static constexpr uint32_t vdr_mmvq = 2;
      };

-     static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+     static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
+         return { block_index * (traits::qk / traits::qr), 0 };
+     }

-     static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+     static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
          auto nblocks = (nrows * (ncols / traits::qk));
-         return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
+         return { nblocks * (QK_K / 2),
+                  (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
      }

      static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }

      constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
-
-     constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
  };

+ template <> struct block_q_t<GGML_TYPE_Q6_K> {
+     struct traits {
+         static constexpr uint32_t qk = QK_K;
+         static constexpr uint32_t qi = QI6_K;
+         static constexpr uint32_t qr = QR6_K;
+         static constexpr uint32_t vdr_mmvq = 1;
+     };
+
+     static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
+         auto low_bits_index = block_index * (traits::qk / traits::qr);
+         // the index of high bits it's after all low bits
+         auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
+         return { low_bits_index, high_bits_index };
+     }
+
+     static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
+         auto nblocks = (nrows * (ncols / traits::qk));
+         auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
+         auto block_scales = total_qs_bytes + block_index * (QK_K / 16);
+         auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16);
+         return { block_scales, sb_scale };
+     }
+
+     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+ };
  } // namespace ggml_sycl_reordered

  #endif // GGML_SYCL_QUANTS_HPP
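
A worked instance of the new Q6_K offset helpers above: for a hypothetical tensor with nrows = 2 and ncols = 512 (so nblocks = 4 with QK_K = 256) and block_index = 3, the returned pairs come out as follows (illustrative only, not part of the commit):

// Illustrative check, not part of the commit: values produced by the
// block_q_t<GGML_TYPE_Q6_K> helpers for nrows = 2, ncols = 512, block_index = 3,
// assuming QK_K = 256 and QR6_K = 2 (so nblocks = 4).
constexpr int QK_K_ex = 256, nblocks_ex = 4, ib_ex = 3;
// get_block_offset(3, 4) -> { ql offset, qh offset }
static_assert(ib_ex * (QK_K_ex / 2) == 384, "low-bit (ql) byte offset of block 3");
static_assert(nblocks_ex * (QK_K_ex / 2) + ib_ex * (QK_K_ex / 4) == 704, "high-bit (qh) byte offset of block 3");
// get_d_offset(2, 512, 3) -> { this block's scales, start of the fp16 d region }
static_assert(nblocks_ex * (QK_K_ex / 2 + QK_K_ex / 4) + ib_ex * (QK_K_ex / 16) == 816, "scales byte offset of block 3");
static_assert(nblocks_ex * (QK_K_ex / 2 + QK_K_ex / 4 + QK_K_ex / 16) == 832, "byte offset where the d values start");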
ggml/src/ggml-sycl/vecdotq.hpp CHANGED
@@ -284,10 +284,11 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
          return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
      }

-     __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
-                                      const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) {
-         const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
-         const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
+     __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
+                                      const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
+                                      const sycl::half2 * q8_1_ds, const int & iqs) {
+         const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset.first;
+         const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
          int v[q4_0_traits::vdr_mmvq];
          int u[2 * q4_0_traits::vdr_mmvq];

@@ -346,15 +347,15 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
      using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
      using q4_k_traits = typename q4_k_block::traits;

-     float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
-                      const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) {
-         const int ib = ibx_offset / (QK_K / 2);
+     __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
+                                      const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
+                                      const sycl::half2 * q8_1_ds, const int & iqs) {
+         const int ib = ibx_offset.first / (QK_K / 2);

          const uint8_t * base = static_cast<const uint8_t *>(vbq);
-         const uint8_t * qs = base + ibx_offset;
-         const int total_qs_bytes = nblocks * (QK_K / 2);
-         const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE;
-         const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
+         const uint8_t * qs = base + ibx_offset.first;
+         const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE;
+         const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset.second);

          const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
          const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
@@ -395,6 +396,66 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
      }
  };

+ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
+     static constexpr ggml_type gtype = GGML_TYPE_Q6_K;
+
+     using q6_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>;
+     using q6_k_traits = typename q6_k_block::traits;
+
+     __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,
+                                                       const int8_t * __restrict__ scales, const float d,
+                                                       const float * __restrict__ d8) {
+         float sumf = 0.0f;
+
+ #pragma unroll
+         for (int i = 0; i < QR6_K; ++i) {
+             const int sc = scales[4 * i];
+
+             const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
+
+             const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
+
+             const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,
+                                                                 dpct::sub_sat());  // vi = (vil | vih) - 32
+
+             sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc);  // SIMD dot product
+         }
+
+         return d * sumf;
+     }
+
+     __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
+                                      const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds,
+                                      const int iqs) {
+         const int ib = ibx_offset.first / (QK_K / 2);
+
+         const uint8_t * base = static_cast<const uint8_t *>(vbq);
+         const uint8_t * ql = base + ibx_offset.first;
+         const uint8_t * qh = base + ibx_offset.second;
+         const int8_t * scales = reinterpret_cast<const int8_t *>(base + d_offset.first);
+         const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib;
+
+         const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
+         const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
+         const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));
+
+         const int vl = get_int_from_uint8(ql, iqs);
+         const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;
+
+         const int8_t * scs = scales + scale_offset;
+
+         int u[QR6_K];
+         float d8[QR6_K];
+
+ #pragma unroll
+         for (int i = 0; i < QR6_K; ++i) {
+             u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);
+             const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
+             d8[i] = ds_values[0];
+         }
+         return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);
+     }
+ };
  #define VDR_Q4_0_Q8_1_MMVQ 2
  #define VDR_Q4_0_Q8_1_MMQ 4