Eve committed
Commit: b4c65b4 · Parent: 59dd404
Q6_K AVX improvements (llama/10118)
* q6_k instruction reordering attempt
* better subtract method: should be theoretically faster
* small improvement with the shuffle LUT, likely because all loads are already done at that stage
* optimize bit fiddling
* handle the -32 offset separately; bsums exists for a reason! (a scalar sketch of the identity follows this message)
* use shift (the byte-level equivalence is sketched after the diff)
* Update ggml-quants.c
* have to update the CI macOS version to 13, as 12 doesn't work now; 13 is still x86
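The "-32 offset" and bsums bullets are the heart of the change. Q6_K stores each weight as an unsigned 6-bit value q with an implicit -32 offset, in blocks of 256 split into 16 groups of 16 with one int8 scale per group, and Q8_K stores bsums, the per-group sums of its int8 values. Since scale * sum((q - 32) * q8) = scale * sum(q * q8) - 32 * scale * sum(q8), the offset no longer needs an element-wise subtraction in the inner loop: the new code folds it into one _mm_madd_epi16 of bsums against the scales plus a shift by 5 (the *32), applied once per block and subtracted from the accumulators after the j loop. Below is a minimal scalar sketch of that identity with made-up data; it is illustrative only, not the ggml kernel, and all names in it are hypothetical.

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        uint8_t q6[256]; // 6-bit quants, offset not yet applied
        int8_t  q8[256]; // int8 activations
        int8_t  sc[16];  // per-group scales (16 groups of 16, as in Q6_K)
        for (int k = 0; k < 256; ++k) { q6[k] = (uint8_t)((k * 37) % 64); q8[k] = (int8_t)(k * 11); }
        for (int g = 0; g < 16; ++g) { sc[g] = (int8_t)(g - 8); }

        // reference: subtract the -32 offset on every element
        int32_t ref = 0;
        for (int g = 0; g < 16; ++g)
            for (int k = 0; k < 16; ++k)
                ref += sc[g] * (q6[16*g + k] - 32) * q8[16*g + k];

        // bsums trick: dot the raw quants, then remove the offset once per block
        int32_t dot = 0, off = 0;
        for (int g = 0; g < 16; ++g) {
            int32_t bsum = 0; // Q8_K precomputes this per group at quantization time
            for (int k = 0; k < 16; ++k) {
                dot  += sc[g] * q6[16*g + k] * q8[16*g + k];
                bsum += q8[16*g + k];
            }
            off += sc[g] * bsum;
        }
        printf("%d == %d\n", ref, dot - 32 * off); // the kernel does the *32 as a shift by 5
        return 0;
    }

In the diff, the correction is built once per block in the q8sclsub_0/q8sclsub_1 lines and applied by the two _mm_sub_epi32 calls, replacing eight _mm_maddubs_epi16 plus eight _mm_sub_epi16 per inner-loop iteration.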
- ggml/src/ggml-quants.c +37 -50

ggml/src/ggml-quants.c CHANGED
@@ -9104,10 +9104,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __AVX__
 
-    const __m128i m4 = _mm_set1_epi8(0xF);
     const __m128i m3 = _mm_set1_epi8(3);
-    const __m128i m32s = _mm_set1_epi8(32);
-    const __m128i m2 = _mm_set1_epi8(2);
+    const __m128i m15 = _mm_set1_epi8(15);
 
     __m256 acc = _mm256_setzero_ps();
 
@@ -9119,12 +9117,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict qh = x[i].qh;
         const int8_t * restrict q8 = y[i].qs;
 
+        // handle the q6_k -32 offset separately using bsums
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
         const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
+        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
+        const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
+        const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
 
         __m128i sumi_0 = _mm_setzero_si128();
         __m128i sumi_1 = _mm_setzero_si128();
 
-        __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+        int is = 0;
+
         for (int j = 0; j < QK_K/128; ++j) {
 
             const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
@@ -9132,26 +9138,26 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
             const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
             const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
-            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
-            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
-            const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
-            const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
-            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
-            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
+            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
+            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
+            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
 
             const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
 
-            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
-            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
-            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
-            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
-            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
-            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
-            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
-            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
+            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
+            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
+            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
+            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
+            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
+            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
+            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
+            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
 
             const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
             const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
@@ -9162,15 +9168,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
             const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
 
-            __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
-            __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
-            __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
-            __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
-            __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
-            __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
-            __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
-            __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
-
             __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
             __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
             __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
@@ -9180,32 +9177,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
             __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
 
-            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
-            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
-            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
-            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
-            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
-            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
-            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
-            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
-
-            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+            is += 4;
 
             p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
-            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
             p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
-            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
             p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
-            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
+            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
             p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
-            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
+            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
 
             sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
             sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
@@ -9214,8 +9199,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         }
 
-        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
-        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
+        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
+        const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
     }
 
     *s = hsum_float_8(acc);
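The "optimize bit fiddling" and "use shift" bullets show up in the q4h_* lines above. The old code extracted each 2-bit high field with shift-then-mask, ((qh >> s) & 3) << 4; the new code masks the field where it sits (constants 3, 12, 48 and -64, i.e. 0xC0 as a signed byte) and moves it to bits 4..5 with at most one shift, and q4h_4/q4h_5 need no shift at all. Because SSE has no 8-bit shift, the kernel shifts 16-bit lanes; masking first also keeps bits from leaking across byte boundaries. The scale lookup got a similar treatment: instead of a running shuffle vector bumped by _mm_add_epi8 every iteration, the new code indexes precomputed shuffle masks via get_scale_shuffle(is + k), a helper ggml already uses in other K-quant paths. Below is a standalone scalar check that the old and new byte-level bit tricks agree; it is illustrative only, not the kernel.

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        for (int v = 0; v < 256; ++v) {
            uint8_t qh = (uint8_t)v;
            // old: shift the field down to bits 0..1, mask with 3, shift up to bits 4..5
            uint8_t old2 = (uint8_t)(((qh >> 2) & 3) << 4);
            uint8_t old4 = (uint8_t)(((qh >> 4) & 3) << 4);
            uint8_t old6 = (uint8_t)(((qh >> 6) & 3) << 4);
            // new: mask the field in place, then at most one shift to bits 4..5
            uint8_t new2 = (uint8_t)((qh & 12) << 2);   // q4h_2/q4h_3
            uint8_t new4 = (uint8_t)(qh & 48);          // q4h_4/q4h_5: already in place
            uint8_t new6 = (uint8_t)((qh & 0xC0) >> 2); // q4h_6/q4h_7: 0xC0 is the -64 in the diff
            assert(old2 == new2 && old4 == new4 && old6 == new6);
        }
        printf("old and new high-bit extraction agree for all byte values\n");
        return 0;
    }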