Kawrakow (ikawrakow) and PeterReid committed
Commit 187ae44 (unverified) · 1 parent: 4649943

Faster AVX2 dot product for IQ2_XS (llama/5187)


* iq2xs: faster AVX2 dot product

* iq2xs: small AVX2 improvement

* Speed up computing sign bits in AVX2 iq2_xs dot product (see the scalar sketch below)

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
Co-authored-by: Peter Reid <[email protected]>
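
For context on the last bullet (the sketch referenced there): the previous AVX2 path looked each group's sign mask up in the keven_signs_q2xs table, while the new code reconstructs the signs in registers. Below is a minimal scalar sketch of that reconstruction, assuming each 16-bit iq2_xs entry holds a 9-bit grid index plus 7 explicit sign bits, with the 8th sign bit implied by even parity; the helper name is illustrative and not part of the commit.

#include <stdint.h>

// Hypothetical helper: recover all 8 sign bits of one iq2_xs entry.
// Bits 0..8 of q2 index iq2xs_grid; bits 9..15 are the 7 explicit sign bits.
// The implied 8th bit is the one that makes the total number of set bits even.
static inline uint8_t iq2xs_full_signs(uint16_t q2) {
    const uint8_t signs7 = (uint8_t)(q2 >> 9);        // the 7 stored sign bits
    const uint8_t parity = __builtin_parity(signs7);  // GCC/Clang builtin: 1 if an odd count is set
    return (uint8_t)(signs7 | (parity << 7));         // restore the implied bit -> even parity overall
}

// Weight j of the group is then negated when bit j of the result is set:
//   y[j] = (signs >> j) & 1 ? -grid[j] : grid[j];

The AVX2 code in the diff does the same thing 16 entries at a time: _mm256_srli_epi16(q2_data, 9) extracts the stored bits, XOR-ing with the >>13 shift folds them into a 4-bit value of equal parity, and the k_bit_helper shuffle acts as a 16-entry nibble-parity table that supplies the missing 0x80 bit.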

Files changed (1)
  1. ggml-quants.c +76 -15
ggml-quants.c CHANGED
@@ -8525,17 +8525,36 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 
     const __m128i m4 = _mm_set1_epi8(0xf);
     const __m128i m1 = _mm_set1_epi8(1);
-    const __m128i m511 = _mm_set1_epi16(511);
-    const __m128i m127 = _mm_set1_epi16(127);
+    const __m256i m511 = _mm256_set1_epi16(511);
+    const __m256i mone = _mm256_set1_epi8(1);
 
-    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
+    const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
+    const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
+    const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
 
     uint64_t aux64;
 
     // somewhat hacky, but gives a significant boost in performance
-    __m128i aux_gindex, aux_sindex;
+    __m256i aux_gindex;
     const uint16_t * gindex = (const uint16_t *)&aux_gindex;
-    const uint16_t * sindex = (const uint16_t *)&aux_sindex;
 
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
@@ -8550,26 +8569,68 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
-        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2);  q2 += 16;
+            aux_gindex = _mm256_and_si256(q2_data, m511);
+
+            const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
+            const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
+            const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
+
+            const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
+            const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
+
             const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
             const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
-            const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
-            aux_gindex = _mm_and_si128(q2_data, m511);
-            aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
-            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
-            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
-            const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
-            const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
-            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
-            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+            const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+            const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+            const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
+            const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
+            const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
+            const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
+
+            __m256i signs;
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
+
             const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
             const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
+            const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
 
             const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
             const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
+            const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
+            const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
 
             sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
             sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
         }
 
         accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
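
A note on the q8s_* lines above: instead of loading precomputed 64-bit sign masks, the new loop turns a byte of sign bits into a per-byte +1/-1 mask for _mm256_sign_epi8 by comparing against a bit-selector pattern. The following is a self-contained sketch of that pattern, simplified to a single 8-bit sign group broadcast across the whole register; the function name and layout are illustrative, not the commit's exact data flow.

#include <immintrin.h>
#include <stdint.h>

// Illustrative helper: expand one byte of sign bits into 32 lanes of +1/-1,
// the form _mm256_sign_epi8 expects. Lane j tests bit (j & 7) of sign_bits.
static inline __m256i expand_sign_byte(uint8_t sign_bits) {
    const __m256i broadcast = _mm256_set1_epi8((char)sign_bits);
    const __m256i bit_selector = _mm256_setr_epi8(
        1, 2, 4, 8, 16, 32, 64, -128, 1, 2, 4, 8, 16, 32, 64, -128,
        1, 2, 4, 8, 16, 32, 64, -128, 1, 2, 4, 8, 16, 32, 64, -128);
    // 0xFF where the corresponding sign bit is set, 0x00 elsewhere.
    const __m256i is_neg = _mm256_cmpeq_epi8(_mm256_and_si256(broadcast, bit_selector), bit_selector);
    // OR with 1 maps 0xFF/0x00 to -1/+1.
    return _mm256_or_si256(is_neg, _mm256_set1_epi8(1));
}

In the diff, the broadcast step is done with _mm256_shuffle_epi8 and the block_sign_shuffle masks, so that four different sign bytes (one per 8-weight group) land in the four 8-byte sub-groups of each q8s_* register rather than the same byte everywhere.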