Svetlozar Georgiev committed
Commit 6ca3a47 · 1 parent: 2008e08

sycl: reordered Q4_K MMVQ (llama/13109)

ggml/src/ggml-sycl/convert.cpp CHANGED
@@ -183,6 +183,24 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
     }
 }
 
+template <typename dst_t>
+static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+    const int64_t nb = k / QK_K;
+    const size_t local_size = 32;
+    const size_t global_size = nb * local_size;
+
+    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+    stream->submit([&](sycl::handler & cgh) {
+        sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
+
+        cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
+                         [=](sycl::nd_item<1> item_ct1) {
+                             dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
+                         });
+    });
+}
+
 template <typename dst_t>
 static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
@@ -504,7 +522,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_sycl;
         case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_sycl;
+            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                return dequantize_row_q4_K_sycl_reorder;
+            } else {
+                return dequantize_row_q4_K_sycl;
+            }
         case GGML_TYPE_Q5_K:
             return dequantize_row_q5_K_sycl;
         case GGML_TYPE_Q6_K:
@@ -556,7 +578,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_sycl;
         case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_sycl;
+            if (dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                return dequantize_row_q4_K_sycl_reorder;
+            } else {
+                return dequantize_row_q4_K_sycl;
+            }
         case GGML_TYPE_Q5_K:
             return dequantize_row_q5_K_sycl;
         case GGML_TYPE_Q6_K:
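
Note on the hunk above: dequantize_row_q4_K_sycl_reorder launches one 32-work-item group per Q4_K super-block (QK_K = 256 values), so each work-item writes 8 outputs (4 low nibbles and 4 high nibbles). A minimal standalone sketch of that launch arithmetic, with the ggml constant QK_K restated here as an assumption:

```cpp
// Sketch only: the launch geometry used by dequantize_row_q4_K_sycl_reorder, restated on the
// host. QK_K and the 32-work-item group size are taken from the diff above; k is an example.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    constexpr int64_t QK_K = 256;            // values per Q4_K super-block (ggml constant, assumed)
    const int64_t k = 4096;                  // example row length, assumed to be a multiple of QK_K
    const int64_t nb = k / QK_K;             // number of super-blocks -> one work-group each
    const size_t local_size = 32;            // work-items per group, as in the kernel
    const size_t global_size = nb * local_size;

    // Each work-item covers QK_K / 32 = 8 outputs: 4 low nibbles and 4 high nibbles.
    assert(QK_K % (int64_t) local_size == 0);
    std::printf("work-groups=%lld, global range=%zu, outputs per work-item=%lld\n",
                (long long) nb, global_size, (long long) (QK_K / (int64_t) local_size));
    return 0;
}
```
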
ggml/src/ggml-sycl/dequantize.hpp CHANGED
@@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
 }
 #endif
 
+template <typename dst_t>
+inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall,
+                                   const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) {
+    const int is = 2 * il;
+    constexpr int n = 4;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, scales_local, sc, m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+
+    get_scale_min_k4(is + 1, scales_local, sc, m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir);
+    for (int l = 0; l < n; ++l) {
+        y[l + 0]  = d1 * (q_vec[l] & 0xF) - m1;
+        y[l + 32] = d2 * (q_vec[l] >> 4) - m2;
+    }
+}
+
 template<typename dst_t>
 static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
@@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
     const int64_t i = item_ct1.get_group(2);
 
 #if QK_K == 256
-    // assume 32 threads
     const int64_t tid = item_ct1.get_local_id(2);
-    const int64_t il = tid/8;
-    const int64_t ir = tid%8;
-    const int64_t is = 2*il;
-    const int64_t n = 4;
+    const int64_t il = tid / 8;
+    const int64_t ir = tid % 8;
 
-    dst_t * y = yy + i*QK_K + 64*il + n*ir;
+    dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
 
     const sycl::half2 dm = x[i].dm;
     const float dall = dm[0];
     const float dmin = dm[1];
 
-    if (tid < 12)
+    if (tid < 12) {
         scales_local[tid] = x[i].scales[tid];
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, scales_local, sc, m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, scales_local, sc, m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
-
-    sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
-    for (int l = 0; l < n; ++l) {
-        y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
-        y[l +32] = d2 * (q_vec[l] >> 4) - m2;
     }
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+    dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir);
 #else
     const int64_t tid = item_ct1.get_local_id(2);
     const uint8_t * q = x[i].qs;
@@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
 #endif
 }
 
+template <typename dst_t>
+static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local,
+                                          const sycl::nd_item<1> & item_ct1, int64_t nb) {
+    const int64_t i   = item_ct1.get_group(0);     // block index
+    const int64_t tid = item_ct1.get_local_id(0);  // thread index within block
+    const int64_t il  = tid / 8;
+    const int64_t ir  = tid % 8;
+
+    dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
+
+    const uint8_t * base          = static_cast<const uint8_t *>(vx);
+    const size_t    qs_offset     = i * (QK_K / 2);
+    const size_t    scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE;
+    const size_t    dm_offset     = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2);
+
+    const uint8_t * qs_ptr     = base + qs_offset;
+    const uint8_t * scales_ptr = base + scales_offset;
+    ggml_half2      dm_values  = *reinterpret_cast<const ggml_half2 *>(base + dm_offset);
+
+    const float dall = dm_values.x();
+    const float dmin = dm_values.y();
+
+    if (tid < 12) {
+        scales_local[tid] = scales_ptr[tid];
+    }
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+    dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir);
+}
+
 template<typename dst_t>
 static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {
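
dequantize_block_q4_K_reorder above assumes the reordered buffer layout produced by reorder_qw_q4_k in ggml-sycl.cpp below: all quant bytes for every block first, then all packed scale bytes, then all d/dmin half2 pairs. A host-side sketch of that offset arithmetic follows, with QK_K, K_SCALE_SIZE and the 4-byte half2 restated as assumptions:

```cpp
// Sketch of the reordered Q4_K buffer layout assumed by dequantize_block_q4_K_reorder:
// [ qs of all blocks | scales of all blocks | dm (half2) of all blocks ].
// QK_K = 256, K_SCALE_SIZE = 12 and the 4-byte half2 are ggml constants restated as assumptions.
#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr int64_t QK_K = 256;
constexpr int64_t K_SCALE_SIZE = 12;
constexpr int64_t SIZEOF_HALF2 = 4;  // two 16-bit floats

struct Q4KReorderOffsets {
    size_t qs;      // first quant byte of block i
    size_t scales;  // first packed scale byte of block i
    size_t dm;      // d/dmin pair of block i
};

// Mirrors qs_offset / scales_offset / dm_offset computed inside the kernel above.
static Q4KReorderOffsets q4_k_reorder_offsets(int64_t i, int64_t nb) {
    return {
        (size_t) (i * (QK_K / 2)),
        (size_t) (nb * (QK_K / 2) + i * K_SCALE_SIZE),
        (size_t) (nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * SIZEOF_HALF2),
    };
}

int main() {
    const int64_t nb = 16;  // example block count
    const Q4KReorderOffsets o = q4_k_reorder_offsets(3, nb);
    std::printf("block 3: qs@%zu scales@%zu dm@%zu\n", o.qs, o.scales, o.dm);
    // The reordered size equals nb * (128 + 12 + 4) bytes, i.e. nb * sizeof(block_q4_K): nothing grows.
    std::printf("total bytes = %lld\n", (long long) (nb * (QK_K / 2 + K_SCALE_SIZE + SIZEOF_HALF2)));
    return 0;
}
```
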
ggml/src/ggml-sycl/dmmv.cpp CHANGED
@@ -1129,7 +1129,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
             dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q4_K:
-            dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                // reorder is currently not supported for dmmv
+                GGML_ABORT("Unimplemented dequantize case case for q4_k reorder");
+            } else {
+                dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            }
             break;
         case GGML_TYPE_Q5_K:
             dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -352,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0 && !g_ggml_sycl_disable_optimize) {
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
@@ -2900,6 +2900,8 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             return true;
+        case GGML_TYPE_Q4_K:
+            return !g_ggml_sycl_prioritize_dmmv;
         default:
             return false;
     }
@@ -2917,6 +2919,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
@@ -2942,16 +2945,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void reorder_qw(char *data_device, const int ncols, const int nrows,
-    size_t size, size_t offset, dpct::queue_ptr stream) {
-    auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
             .wait()));
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
+    auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
     stream->parallel_for(
@@ -2965,18 +2968,59 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
                 *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
             }
             *(d_ptr + ib) = x[ib].d;
-        });
+        }).wait_and_throw();
+
+    sycl::free(tmp_buf, *stream);
+}
+
+static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr     = data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr     = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) {
+        const block_q4_K * x = (const block_q4_K *) tmp_buf;
+        const int ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    }).wait_and_throw();
 
     sycl::free(tmp_buf, *stream);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
-    char*data_device = (char*)src0->data;
+    uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_ABORT("reorder_qw() called with unsupported type");
+            break;
+    }
 }
 
 static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
@@ -3019,8 +3063,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
     extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;
 
@@ -3043,13 +3097,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     }
 
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
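
reorder_qw_q4_k above rewrites the tensor in place from an array of block_q4_K structs into the split [qs | scales | dm] layout, staging the original bytes in a temporary shared allocation. The following is a CPU-only sketch of the same transformation; the block_q4_K definition and the half2 stand-in are assumptions meant to mirror ggml-common.h, not code from this commit:

```cpp
// CPU-only sketch of the transformation reorder_qw_q4_k performs on the device:
// array-of-structs block_q4_K -> [all qs][all scales][all dm], written back in place via a copy.
// The block_q4_K definition below is an assumption mirroring ggml-common.h, not part of the commit.
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int QK_K = 256;
constexpr int K_SCALE_SIZE = 12;

struct half2 { uint16_t x, y; };              // stand-in for sycl::half2 / ggml_half2

struct block_q4_K {                           // assumed field order: dm, scales, qs (144 bytes)
    half2   dm;                               // super-block scale and min as an fp16 pair
    uint8_t scales[K_SCALE_SIZE];             // packed 6-bit sub-block scales/mins
    uint8_t qs[QK_K / 2];                     // 4-bit quants, two per byte
};

static void reorder_q4_k_cpu(uint8_t * data, size_t size) {
    const size_t nblocks = size / sizeof(block_q4_K);
    std::vector<uint8_t> tmp(data, data + size);      // same role as the sycl::malloc_shared staging copy
    const auto * x = reinterpret_cast<const block_q4_K *>(tmp.data());

    uint8_t * qs_ptr     = data;
    uint8_t * scales_ptr = qs_ptr + (QK_K / 2) * nblocks;
    half2   * dm_ptr     = reinterpret_cast<half2 *>(scales_ptr + K_SCALE_SIZE * nblocks);

    for (size_t ib = 0; ib < nblocks; ++ib) {          // the kernel does this once per block in parallel
        std::memcpy(qs_ptr + ib * (QK_K / 2), x[ib].qs, QK_K / 2);
        std::memcpy(scales_ptr + ib * K_SCALE_SIZE, x[ib].scales, K_SCALE_SIZE);
        dm_ptr[ib] = x[ib].dm;
    }
}

int main() {
    std::vector<uint8_t> buf(4 * sizeof(block_q4_K), 0);  // 4 dummy blocks
    reorder_q4_k_cpu(buf.data(), buf.size());
    return 0;
}
```
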
ggml/src/ggml-sycl/mmvq.cpp CHANGED
@@ -24,6 +24,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
     const int blocks_per_row = ncols / block_traits::qk;
     constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
     constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
+    const int nblocks = nrows * (ncols / block_traits::qk);
 
     static_assert(blocks_per_subgroup > 0);
     static_assert(block_elements_per_subgroup > 0);
@@ -45,7 +46,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
             // x block quant index when casting the quants to int
             const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
 
-            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
+            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
         }
     }
 
@@ -739,6 +740,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
     }
 }
 
+static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
+                                               const int nrows, dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+
+    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
+    constexpr size_t num_subgroups = 16;
+    GGML_ASSERT(block_num_y % num_subgroups == 0);
+
+    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
+    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
+
+    stream->submit([&](sycl::handler & cgh) {
+        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
+                         [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                             mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
+                                                                                           nrows, nd_item);
+                         });
+    });
+}
+
+
 static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,
@@ -1035,7 +1057,12 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
             mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            } else {
+                mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            }
            break;
         case GGML_TYPE_Q5_K:
             mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
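
reorder_mul_mat_vec_q4_k_q8_1_sycl above requires block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y) to be a multiple of 16 sub-groups, i.e. the row count must divide evenly into the chosen work-group shape. A short sketch of that shape arithmetic; GGML_SYCL_MMV_Y and WARP_SIZE are backend constants that do not appear in this diff, so the values below are illustrative assumptions:

```cpp
// Sketch of the launch-shape constraint in reorder_mul_mat_vec_q4_k_q8_1_sycl.
// GGML_SYCL_MMV_Y and WARP_SIZE are backend constants not shown in this diff; the values
// here are illustrative assumptions.
#include <cassert>
#include <cstdio>

constexpr int GGML_SYCL_MMV_Y = 1;   // assumed
constexpr int WARP_SIZE       = 32;  // assumed sub-group size

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int nrows = 4096;                                   // example row count
    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
    constexpr int num_subgroups = 16;

    // The kernel asserts this, so nrows must be a multiple of num_subgroups * GGML_SYCL_MMV_Y.
    assert(block_num_y % num_subgroups == 0);

    // nd_range x-dimension: block_num_y * WARP_SIZE work-items split into work-groups of
    // num_subgroups * WARP_SIZE, i.e. block_num_y / num_subgroups work-groups.
    std::printf("global x = %d, work-group x = %d, work-groups = %d\n",
                block_num_y * WARP_SIZE, num_subgroups * WARP_SIZE, block_num_y / num_subgroups);
    return 0;
}
```
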
ggml/src/ggml-sycl/quants.hpp CHANGED
@@ -56,6 +56,28 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
 };
 
+template <> struct block_q_t<GGML_TYPE_Q4_K> {
+    struct traits {
+        static constexpr uint32_t qk = QK_K;
+        static constexpr uint32_t qi = QI4_K;
+        static constexpr uint32_t qr = QR4_K;
+        static constexpr uint32_t vdr_mmvq = 2;
+    };
+
+    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+
+    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+        auto nblocks = (nrows * (ncols / traits::qk));
+        return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
+    }
+
+    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+
+    constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
+
+    constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
+};
+
 } // namespace ggml_sycl_reordered
 
 #endif // GGML_SYCL_QUANTS_HPP
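
These traits are what mul_mat_vec_q_reorder uses to locate Q4_K data: get_block_offset() points at a block's 128 packed quant bytes, and get_d_offset() skips the whole qs region plus the whole scales region before indexing the half2 pairs. A standalone sketch restating the two formulas against the [qs | scales | dm] layout; the numeric constants are ggml values repeated here as assumptions:

```cpp
// Sketch restating block_q_t<GGML_TYPE_Q4_K>::get_block_offset / get_d_offset and checking them
// against the [qs | scales | dm] layout written by reorder_qw_q4_k. QK_K, QR4_K, K_SCALE_SIZE and
// the 4-byte ggml_half2 are ggml constants repeated here as assumptions.
#include <cstdio>

constexpr int QK_K = 256, QR4_K = 2, K_SCALE_SIZE = 12, SIZEOF_GGML_HALF2 = 4;

constexpr int get_block_offset(int block_index) {
    return block_index * (QK_K / QR4_K);             // 128 packed quant bytes per block
}

constexpr int get_d_offset(int nrows, int ncols, int block_index) {
    const int nblocks = nrows * (ncols / QK_K);
    return nblocks * QK_K / 2                        // past every block's quants
         + nblocks * K_SCALE_SIZE                    // past every block's scales
         + block_index * SIZEOF_GGML_HALF2;          // then index the dm pairs
}

int main() {
    const int nrows = 32, ncols = 1024;              // example tensor shape
    const int nblocks = nrows * (ncols / QK_K);
    std::printf("qs bytes=%d, scales bytes=%d, dm[0] offset=%d, dm[last] offset=%d\n",
                nblocks * QK_K / 2, nblocks * K_SCALE_SIZE,
                get_d_offset(nrows, ncols, 0), get_d_offset(nrows, ncols, nblocks - 1));
    std::printf("block 3 qs offset=%d\n", get_block_offset(3));
    return 0;
}
```
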
ggml/src/ggml-sycl/vecdotq.hpp CHANGED
@@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
     }
 
     __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
-                                     const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+                                     const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
         const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
         const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
         int v[q4_0_traits::vdr_mmvq];
@@ -303,6 +303,67 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
     };
 };
 
+static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,
+                                             const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1,
+                                             const int & iqs) {
+    int v[2];
+    int u[2 * QR4_K];
+    float d8[QR4_K];
+
+    v[0] = q4[0];
+    v[1] = q4[4];
+
+    uint16_t aux[2];
+    const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
+    if (j < 2) {
+        aux[0] = scales[j + 0] & 0x3f3f;
+        aux[1] = scales[j + 2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
+    }
+
+    const uint8_t * sc = (const uint8_t *) aux;
+    const uint8_t * m  = sc + 2;
+
+    const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
+
+    for (int i = 0; i < QR4_K; ++i) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        d8[i] = bq8i->ds[0];
+
+        const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4);
+        u[2 * i + 0] = q8[0];
+        u[2 * i + 1] = q8[4];
+    }
+
+    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8);
+}
+
+template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
+    static constexpr ggml_type gtype = GGML_TYPE_Q4_K;
+
+    using q4_k_block  = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
+    using q4_k_traits = typename q4_k_block::traits;
+
+    float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
+                     const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
+        const int ib = ibx_offset / (QK_K / 2);
+
+        const uint8_t * base = static_cast<const uint8_t *>(vbq);
+        const uint8_t * qs   = base + ibx_offset;
+        const int total_qs_bytes = nblocks * (QK_K / 2);
+        const uint8_t * scs  = base + total_qs_bytes + ib * K_SCALE_SIZE;
+        const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
+
+        const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
+        const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
+        const uint16_t * scales = (const uint16_t *) scs;
+
+        return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
+    }
+};
+
 #define VDR_Q4_0_Q8_1_MMVQ 2
 #define VDR_Q4_0_Q8_1_MMQ 4
 
@@ -649,52 +710,17 @@ vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
     return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
 }
 
-static __dpct_inline__ float
-vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
+static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
+                                               const int & iqs) {
 #ifndef GGML_QKK_64
-    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
-
-    int v[2];
-    int u[2*QR4_K];
-    float d8[QR4_K];
 
-    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
-
-    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
-    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
-    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
-    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
-
-    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
-    v[0] = q4[0];
-    v[1] = q4[4];
-
-    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
-    uint16_t aux[2];
-    const int j = bq8_offset/2;
-    if (j < 2) {
-        aux[0] = scales[j+0] & 0x3f3f;
-        aux[1] = scales[j+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
-    }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m = sc + 2;
-
-    for (int i = 0; i < QR4_K; ++i) {
-        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds[0];
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
 
-        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
-        u[2*i+0] = q8[0];
-        u[2*i+1] = q8[4];
-    }
+    const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
+    const int * q4 = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
+    const uint16_t * scales = (const uint16_t *) bq4_K->scales;
 
-    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+    return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs);
 
 #else
 
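
The refactor above drops the comments that spelled out how iqs maps to bq8_offset and to the q4 byte offset. The sketch below re-derives that mapping from the formulas kept in vec_dot_q4_K_q8_1_common, for the even iqs values 0..30 noted in the removed comment; QR4_K = 2 and QI8_1 = 8 are ggml constants restated as assumptions:

```cpp
// Sketch re-deriving the iqs -> (bq8_offset, q4 byte offset) mapping used by
// vec_dot_q4_K_q8_1_common, replacing the comments removed above. QR4_K = 2 and QI8_1 = 8
// are ggml constants restated as assumptions.
#include <cstdio>

constexpr int QR4_K = 2;
constexpr int QI8_1 = 8;

int main() {
    for (int iqs = 0; iqs <= 30; iqs += 2) {
        const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));      // 0, 2, 4 or 6
        const int q4_byte    = 16 * bq8_offset + 4 * ((iqs / 2) % 4);  // 0,4,8,12 / 32,36,40,44 / ...
        std::printf("iqs=%2d -> bq8_offset=%d, q4 byte offset=%3d\n", iqs, bq8_offset, q4_byte);
    }
    return 0;
}
```
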