Atharva Dubey and Alberto Cabrera Pérez committed
Commit c4e62cd · Parent: 73547ad

sycl: quantize and reorder the input to q8_1 when reorder is enabled (llama/13826)


* [WIP]: fuse q8 quantization and reorder

* wip2: fuse q8 quantization and reorder

* working q8 reorder commit

* restored common.hpp

* remove debug prints

* remove unnecessary headers and remove trailing whitespace

* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Co-authored-by: Alberto Cabrera Pérez <[email protected]>

---------

Co-authored-by: Alberto Cabrera Pérez <[email protected]>
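
The reorder path changes how the quantized src1 (Q8_1) buffer is laid out: within each padded row, the int8 quants of every QK8_1-wide block are stored back to back, and the per-block (d, sum) half2 pairs follow as a separate section that starts at byte offset kx into the row. Below is a minimal sketch of that offset arithmetic, assuming ggml's QK8_1 = 32 and a 36-byte block_q8_1 (32 quants plus a 4-byte half2); the helper names quant_offset and ds_offset are illustrative and not code from this commit.

// Layout sketch for the reordered Q8_1 buffer written by quantize_and_reorder_q8_1.
// Assumes QK8_1 = 32 and sizeof(block_q8_1) = 36 bytes; helper names are illustrative.
#include <cstddef>

constexpr int         kQK8_1        = 32;   // quants per block (QK8_1 in ggml)
constexpr std::size_t kBlockQ8Bytes = 36;   // sizeof(block_q8_1): 32 x int8 + 4-byte half2 ds

// Byte offset of the first quant of block `col` in row `row` (kx_padded = padded row width).
inline std::size_t quant_offset(int row, int col, int kx_padded) {
    const std::size_t row_bytes = std::size_t(kx_padded / kQK8_1) * kBlockQ8Bytes;
    return std::size_t(row) * row_bytes + std::size_t(col) * kQK8_1;
}

// Byte offset of block `col`'s (d, sum) pair: the ds section begins right after the kx quants.
inline std::size_t ds_offset(int row, int col, int kx, int kx_padded) {
    const std::size_t row_bytes = std::size_t(kx_padded / kQK8_1) * kBlockQ8Bytes;
    return std::size_t(row) * row_bytes + std::size_t(kx) + std::size_t(col) * 4;   // 4 = sizeof(half2)
}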

ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -1434,6 +1434,59 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
     reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
 }
 
+template <int ElementsPerWI>
+static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
+                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
+    /*
+        Quantizes and reorders the resultant q8 tensor in a per row fashion
+        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
+    */
+
+    auto subgroup_id = it.get_group(0);
+    auto wi_id = it.get_local_id(0);
+
+    const int num_blocks_per_row = kx / QK8_1;
+    auto row = subgroup_id / num_blocks_per_row;
+    auto col = subgroup_id % num_blocks_per_row;
+
+    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
+    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
+
+    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
+    auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
+
+    sycl::vec<float, ElementsPerWI> wi_f32_vals;
+    sycl::vec<int8_t, ElementsPerWI> quantized_values;
+
+    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
+    wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
+
+    float sum = 0.0f;
+    float amax = 0.0f;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        sum += wi_f32_vals[i];
+        amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
+        quantized_values[i] = 0;
+    }
+    sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
+    amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
+    float d = amax == 0 ? 1 : amax / 127;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
+    }
+
+    d = amax == 0 ? 0 : d;
+
+    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
+    if (wi_id == 0) {
+        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
+    }
+}
+
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1718,23 +1771,30 @@ static void pool2d_nchw_kernel(
     o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
-                                   const int ky, const int kx_padded,
-                                   queue_ptr stream) {
-    const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-    const sycl::range<3> num_blocks(1, ky, block_num_x);
-    int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-    static_assert(QK8_1 % WARP_SIZE == 0);
-    const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
+                                   bool reorder_q8_tensor, queue_ptr stream) {
+    if (reorder_q8_tensor) {
+        auto local_range = std::size_t(WARP_SIZE);
+        auto num_quant_blocks = ky * (kx / QK8_1);
+        auto global_range = num_quant_blocks * local_range;
+        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
+                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
+                             });
+    } else {
+        const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
+        const sycl::range<3> num_blocks(1, ky, block_num_x);
+        int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
+        static_assert(QK8_1 % WARP_SIZE == 0);
+        const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
+        {
+            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
-        stream->parallel_for(
-            sycl::nd_range<3>(num_blocks * block_size, block_size),
-            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-            });
+            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
+                                 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                     quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
+                                 });
+        }
     }
 }
 
@@ -2446,9 +2506,10 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
         dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
         if (src1_on_device && src1_is_contiguous) {
+            bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
             scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
                                                  /*num_src=*/2, " : converting src1 to Q8_1");
-            quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+            quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
             /*
             DPCT1010:90: SYCL uses exceptions to report errors and does not
             use the error codes. The call was replaced with 0. You need to
@@ -2554,7 +2615,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
         if (convert_src1_to_q8_1 && !src1_is_contiguous) {
             scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
                                                  /*num_src=*/2, " : converting src1 to Q8_1");
-            quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+            quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
             /*
             DPCT1010:92: SYCL uses exceptions to report errors and does
             not use the error codes. The call was replaced with 0. You
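
For the reorder path above, quantize_row_q8_1_sycl launches a 1-D nd_range with one work-group of WARP_SIZE work-items per quant block, so each work-item handles QK8_1 / WARP_SIZE consecutive floats and the block's sum and amax come from reduce_over_group. A small standalone sketch of the same range arithmetic, assuming WARP_SIZE = 32 and QK8_1 = 32 (one element per work-item) and hypothetical kx/ky values:

// Sketch of the launch-range arithmetic used by the reorder path (illustrative only).
#include <cstddef>
#include <cstdio>

int main() {
    const int WARP_SIZE = 32, QK8_1 = 32;        // values used by the SYCL backend
    const int kx = 4096, ky = 2;                 // hypothetical row width and row count
    const std::size_t local_range  = WARP_SIZE;                       // one sub-group per work-group
    const std::size_t num_blocks   = std::size_t(ky) * (kx / QK8_1);  // quant blocks overall
    const std::size_t global_range = num_blocks * local_range;
    const int elements_per_wi      = QK8_1 / WARP_SIZE;               // ElementsPerWI template argument
    std::printf("blocks=%zu global=%zu local=%zu elems/wi=%d\n",
                num_blocks, global_range, local_range, elements_per_wi);
    return 0;
}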
ggml/src/ggml-sycl/mmvq.cpp CHANGED
@@ -29,8 +29,6 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
     static_assert(blocks_per_subgroup > 0);
     static_assert(block_elements_per_subgroup > 0);
 
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
     float partial_sum = 0.0f;
     for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
         const int ibx = row * blocks_per_row + i; // x block index
@@ -40,13 +38,15 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
 
         // Y block index that aligns with ibx
        const int iby = i * block_type::block_to_q8_1_ratio();
+        const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
+        const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2));
 
 #pragma unroll
         for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
             // x block quant index when casting the quants to int
             const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
 
-            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
+            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks);
         }
     }
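
On the consuming side, mul_mat_vec_q_reorder no longer casts vy to block_q8_1; it indexes the flat reordered row directly: the quants of block iby start at byte iby * QK8_1 and the matching (d, sum) pair sits at byte ncols + iby * sizeof(sycl::half2). A host-side sketch of that addressing under the same layout assumptions (fp16 kept as raw uint16_t so the snippet stays self-contained; not code from this commit):

// Addressing a reordered Q8_1 row the way the consumer does (illustrative only).
#include <cstdint>
#include <cstring>
#include <vector>

struct DsPair { std::uint16_t d_bits, sum_bits; };   // stand-in for sycl::half2

int main() {
    const int QK8_1 = 32, ncols = 128;
    const int blocks = ncols / QK8_1;
    std::vector<std::uint8_t> row(ncols + blocks * sizeof(DsPair), 0);   // one reordered row

    const int iby = 2;                                                   // y block index
    const std::int8_t * quants = reinterpret_cast<const std::int8_t *>(row.data()) + iby * QK8_1;
    DsPair ds;
    std::memcpy(&ds, row.data() + ncols + iby * sizeof(DsPair), sizeof(DsPair));
    (void) quants; (void) ds;   // the dot-product kernels read 4-byte groups of quants and scale by d
    return 0;
}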
ggml/src/ggml-sycl/vecdotq.hpp CHANGED
@@ -285,21 +285,21 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
     }
 
     __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
-                                     const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
+                                     const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) {
         const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
         const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
         int v[q4_0_traits::vdr_mmvq];
         int u[2 * q4_0_traits::vdr_mmvq];
 
 #pragma unroll
         for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
             v[i] = get_int_from_uint8(bq4_0, iqs + i);
-            u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
-            u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi);
+            u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);
+            u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi);
         }
 
-        return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds);
+        return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds);
     };
 };
 
@@ -347,7 +347,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
     using q4_k_traits = typename q4_k_block::traits;
 
     float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
-                     const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
+                     const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) {
         const int ib = ibx_offset / (QK_K / 2);
 
         const uint8_t * base = static_cast<const uint8_t *>(vbq);
@@ -360,7 +360,38 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
         const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
         const uint16_t * scales = (const uint16_t *) scs;
 
-        return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
+        int v[2];
+        int u[2 * QR4_K];
+        float d8[QR4_K];
+
+        v[0] = q4[0];
+        v[1] = q4[4];
+
+        uint16_t aux[2];
+        const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
+        if (j < 2) {
+            aux[0] = scales[j + 0] & 0x3f3f;
+            aux[1] = scales[j + 2] & 0x3f3f;
+        } else {
+            aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
+            aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
+        }
+
+        const uint8_t * sc = (const uint8_t *) aux;
+        const uint8_t * m = sc + 2;
+
+        for (int i = 0; i < QR4_K; ++i) {
+            const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
+            sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);
+
+            d8[i] = ds_values[0];
+
+            const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4);
+            u[2 * i + 0] = q8[0];
+            u[2 * i + 1] = q8[4];
+        }
+
+        return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8);
     }
 };
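
The Q4_K specialization now inlines what vec_dot_q4_K_q8_1_common previously did against block_q8_1: for each of the QR4_K q8_1 sub-blocks it loads two 4-byte groups of quants from the flat quant pointer and takes d8[i] from the first component of the corresponding (d, sum) pair. A stripped-down sketch of just that gather, with fp16 replaced by float for brevity (illustrative, not the committed code):

// Per-sub-block gather in the reordered Q4_K path (illustrative; QR4_K = 2, QK8_1 = 32 as in ggml).
#include <cstdint>

constexpr int QR4_K = 2, QK8_1 = 32;

struct DsF32 { float d, sum; };   // stand-in for the sycl::half2 (d, sum) pair

inline void gather_q8_operands(const std::int8_t * q8_1_quant_ptr, const DsF32 * q8_1_ds,
                               int bq8_offset, int iqs, int u[2 * QR4_K], float d8[QR4_K]) {
    for (int i = 0; i < QR4_K; ++i) {
        const std::int8_t * quant_base = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
        d8[i] = q8_1_ds[bq8_offset + i].d;                               // block scale; sum is unused here
        const int * q8 = reinterpret_cast<const int *>(quant_base) + ((iqs / 2) % 4);
        u[2 * i + 0] = q8[0];                                            // 4 quants packed into one int
        u[2 * i + 1] = q8[4];                                            // the group 16 bytes further into the block
    }
}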