Rémy O committed
Commit 591cbfb · 1 Parent(s): 092277a

ggml-cpu: faster AVX2 variant for IQ1_M (llama/12216)

Files changed (1):
  1. ggml/src/ggml-cpu/ggml-cpu-quants.c +17 -5
ggml/src/ggml-cpu/ggml-cpu-quants.c CHANGED
@@ -11718,9 +11718,12 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
 
 #elif defined __AVX2__
 
-    const __m256i mask = _mm256_set1_epi16(2 * 0x7);
+    const __m256i mask = _mm256_set1_epi16(0x7);
     const __m256i mone = _mm256_set1_epi16(1);
     const __m256i mone8 = _mm256_set1_epi8(1);
+    const __m256i mtwo8 = _mm256_set1_epi8(2);
+    // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half.
+    const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0);
 
     __m256 accum1 = _mm256_setzero_ps();
     __m256 accum2 = _mm256_setzero_ps();
@@ -11732,6 +11735,14 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
         const uint16_t * sc = (const uint16_t *)x[i].scales;
 
         scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+        // Extract 3-bit scales (16 values)
+        __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc);
+        scales = _mm256_srlv_epi64(scales, scales_shift);
+        scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone);
+
+        // Indices to repeat each scale 8 times.
+        __m256i scales_idx1 = _mm256_set1_epi16(0x0100);
+        __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8));
 
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
@@ -11777,11 +11788,12 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
             const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1));
             const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2));
 
-            __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 2), _mm_set1_epi16(sc[ib/2] << 1));
-            __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 8), _mm_set1_epi16(sc[ib/2] >> 5));
+            __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1);
+            __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2);
+
+            scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8);
+            scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8);
 
-            scale1 = _mm256_add_epi16(_mm256_and_si256(scale1, mask), mone);
-            scale2 = _mm256_add_epi16(_mm256_and_si256(scale2, mask), mone);
             const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
             const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
             const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
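
The new per-block precomputation replaces the per-iteration scalar scale work. Judging from the shifts in the old code and the scale.u16 line, the 16 sub-block scales of an IQ1_M block are 3-bit fields packed four per uint16_t in the low 12 bits of x[i].scales (the top nibbles feed the fp16 super-scale), and each sub-block multiplier is 2*s + 1. The lane shifts 0, 6, 3, 9 route fields 0 and 2 of every word to the lower 128-bit half and fields 1 and 3 to the upper half, which is what the in-lane VPSHUFB later needs. Below is a standalone sketch that mirrors the precomputation and checks it against that scalar formula; it is illustrative only, not part of the patch, and the sc[] values are made up (compile with -mavx2).

// Illustrative sketch, not part of the patch: mirrors the new per-block
// scale extraction and compares it with the scalar formula implied by the
// old code path. The sc[] contents are hypothetical.
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    uint16_t sc[4] = { 0x1234, 0x5678, 0x9abc, 0xdef0 };   // hypothetical packed scales
    uint64_t packed;
    memcpy(&packed, sc, sizeof packed);

    const __m256i mask         = _mm256_set1_epi16(0x7);
    const __m256i mone         = _mm256_set1_epi16(1);
    // Lane shifts 0, 6, 3, 9: fields 0 and 2 of each word land in the lower
    // 128-bit half, fields 1 and 3 in the upper half (VPSHUFB stays in-lane).
    const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0);

    __m256i scales = _mm256_set1_epi64x((int64_t)packed);
    scales = _mm256_srlv_epi64(scales, scales_shift);
    scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone);

    int16_t out[16];
    _mm256_storeu_si256((__m256i *)out, scales);
    for (int e = 0; e < 16; ++e) {
        static const int field[4] = { 0, 2, 1, 3 };          // per 64-bit lane
        int ref = 2 * ((sc[e % 4] >> (3 * field[e / 4])) & 0x7) + 1;
        printf("elem %2d: vector %2d, scalar %2d\n", e, out[e], ref);
    }
    return 0;
}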
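
In the inner loop, the two shuffles then pick the scales back out of that precomputed vector. The index 0x0100 is the byte pair {0, 1}, i.e. the first 16-bit element of each 128-bit lane, repeated across the lane; adding mtwo8 to every index byte steps to the next element on each iteration, and the +8 offset in scales_idx2 selects the element four positions further on. A minimal standalone demonstration of that trick, illustrative and not from the patch:

// Illustrative sketch: in-lane VPSHUFB broadcast of one 16-bit element per
// 128-bit lane, advanced by bumping the index bytes.
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // Eight 16-bit elements per 128-bit lane; _mm256_shuffle_epi8 indexes
    // bytes within each lane independently.
    __m256i v   = _mm256_setr_epi16(10, 11, 12, 13, 14, 15, 16, 17,
                                    20, 21, 22, 23, 24, 25, 26, 27);
    __m256i idx = _mm256_set1_epi16(0x0100);   // byte pair {0,1} -> element 0 of each lane
    __m256i two = _mm256_set1_epi8(2);         // same role as mtwo8 in the patch

    for (int step = 0; step < 3; ++step) {
        __m256i b = _mm256_shuffle_epi8(v, idx);   // element `step`, repeated 8x per lane
        int16_t out[16];
        _mm256_storeu_si256((__m256i *)out, b);
        printf("step %d: lower lane %d, upper lane %d\n", step, out[0], out[8]);
        idx = _mm256_add_epi8(idx, two);           // move to the next 16-bit element
    }
    return 0;
}

With the lane layout above, scale1 therefore carries field 0 of the current scale word in its lower half and field 1 in its upper half, and scale2 carries fields 2 and 3, matching what the removed MM256_SET_M128I lines computed from sc[ib/2] on every iteration.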