Eve committed
Commit b4c65b4 · 1 parent: 59dd404

Q6_K AVX improvements (llama/10118)

* q6_k instruction reordering attempt

* better subtract method

* should be theoretically faster

  small improvement with shuffle LUT, likely because all loads are already done at that stage

* optimize bit fiddling

* handle -32 offset separately. bsums exists for a reason! (see the scalar sketch after this list)

* use shift

* Update ggml-quants.c

* have to update CI macOS version to 13, as 12 doesn't work now. 13 is still x86
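
The "-32 offset" bullet above is the core of the change: Q6_K stores its 6-bit quants as 0..63 with an implicit -32, and rather than subtracting that offset from every product inside the inner loop, the correction can be computed once per super-block from the Q8_K group sums (bsums). A minimal scalar sketch of the identity the patch exploits (the helper name below is hypothetical, not part of the commit):

    #include <stdint.h>

    // For one super-block (QK_K = 256 values, 16 groups of 16):
    //   sum_g scale[g] * sum_l (q6[16*g + l] - 32) * q8[16*g + l]
    //     = sum_g scale[g] * sum_l q6[16*g + l] * q8[16*g + l]
    //       - 32 * sum_g scale[g] * bsums[g]      // bsums[g] = sum_l q8[16*g + l]
    // so the -32 part folds into one correction term per super-block.
    static int32_t q6_k_offset_correction(const int8_t scales[16], const int16_t bsums[16]) {
        int32_t corr = 0;
        for (int g = 0; g < 16; ++g) {
            corr += (int32_t) scales[g] * bsums[g];
        }
        return 32 * corr;   // the AVX path does the *32 with _mm_slli_epi32(..., 5)
    }

In the diff below, this is what the new q8sclsub_0/q8sclsub_1 vectors hold: the correction is subtracted from sumi_0/sumi_1 once per super-block, which is why the per-vector q8s_* / _mm_sub_epi16 work disappears from the inner loop.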

Files changed (1)
  1. ggml/src/ggml-quants.c +37 -50
ggml/src/ggml-quants.c CHANGED
@@ -9104,10 +9104,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __AVX__
 
-    const __m128i m4 = _mm_set1_epi8(0xF);
     const __m128i m3 = _mm_set1_epi8(3);
-    const __m128i m32s = _mm_set1_epi8(32);
-    const __m128i m2 = _mm_set1_epi8(2);
+    const __m128i m15 = _mm_set1_epi8(15);
 
     __m256 acc = _mm256_setzero_ps();
 
@@ -9119,12 +9117,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict qh = x[i].qh;
         const int8_t * restrict q8 = y[i].qs;
 
+        // handle the q6_k -32 offset separately using bsums
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
         const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
+        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
+        const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
+        const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
 
         __m128i sumi_0 = _mm_setzero_si128();
         __m128i sumi_1 = _mm_setzero_si128();
 
-        __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+        int is = 0;
+
         for (int j = 0; j < QK_K/128; ++j) {
 
             const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
@@ -9132,26 +9138,26 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
             const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
             const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
-            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
-            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
-            const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
-            const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
-            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
-            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
+            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
+            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
+            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
 
             const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
 
-            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
-            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
-            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
-            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
-            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
-            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
-            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
-            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
+            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
+            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
+            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
+            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
+            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
+            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
+            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
+            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
 
             const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
             const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
@@ -9162,15 +9168,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
             const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
 
-            __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
-            __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
-            __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
-            __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
-            __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
-            __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
-            __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
-            __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
-
             __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
             __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
             __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
@@ -9180,32 +9177,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
             __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
 
-            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
-            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
-            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
-            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
-            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
-            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
-            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
-            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
-
-            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+            is += 4;
 
             p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
-            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
             p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
-            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
             p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
-            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
+            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
             p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
-            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
+            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
 
             sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
             sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
@@ -9214,8 +9199,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         }
 
-        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
-        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
+        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
+        const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
     }
 
     *s = hsum_float_8(acc);
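
A note on the "optimize bit fiddling" and "use shift" bullets: the old code extracted each pair of high bits with a shift, a mask against m3, and another shift into the nibble position, while the new code masks the pair in place (12 = 0x0C, 48 = 0x30, -64 = 0xC0) and moves it to bit positions 4..5 with at most one shift. The 16-bit SIMD shifts behave byte-wise here because the masks keep the shifted bits inside each byte. A small standalone check of the equivalence for the q4h_2 and q4h_6 patterns (illustrative only, not part of the commit):

    #include <assert.h>
    #include <stdint.h>

    // Old vs. new placement of the two qh bits that extend the low nibble to 6 bits.
    static uint8_t hi_old_2(uint8_t qh) { return (uint8_t)(((qh >> 2) & 0x03) << 4); } // shift, mask, shift
    static uint8_t hi_new_2(uint8_t qh) { return (uint8_t)((qh & 0x0C) << 2); }        // mask in place, one shift

    static uint8_t hi_old_6(uint8_t qh) { return (uint8_t)(((qh >> 6) & 0x03) << 4); }
    static uint8_t hi_new_6(uint8_t qh) { return (uint8_t)((qh & 0xC0) >> 2); }

    int main(void) {
        for (int b = 0; b < 256; ++b) {
            assert(hi_old_2((uint8_t)b) == hi_new_2((uint8_t)b));
            assert(hi_old_6((uint8_t)b) == hi_new_6((uint8_t)b));
        }
        return 0;
    }

The q4h_4/q4h_5 case needs no shift at all, since the mask 48 already leaves the bits at positions 4..5.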