Faster AVX2 dot product for IQ2_XS (llama/5187)
* iq2xs: faster AVX2 dot product
* iq2xs: small AVX2 improvement
* Speed up computing sign bits in AVX2 iq2_xs dot product
---------
Co-authored-by: Iwan Kawrakow <[email protected]>
Co-authored-by: Peter Reid <[email protected]>
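
For context on the third bullet: each 16-bit iq2_xs quant packs a 9-bit grid index (hence the 511 mask) in its low bits and 7 explicit sign bits above them; the 8th sign bit is implied as the odd parity of the stored 7. The sketch below is a scalar model of what the new `k_bit_helper` shuffle computes sixteen quants at a time (the helper name is ours, not part of the patch; `__builtin_parity` is a GCC/Clang builtin):

#include <stdint.h>

// Scalar model of the new sign-bit reconstruction (illustrative only).
// A 16-bit iq2_xs quant stores: bits 0-8 = grid index, bits 9-15 = 7 sign bits.
static uint8_t iq2xs_full_signs(uint16_t q2) {
    const uint8_t partial = (uint8_t)(q2 >> 9);                     // the 7 stored sign bits
    // Folding the top 3 bits onto the low 4 preserves overall parity, which
    // is why the AVX2 code can look the parity up in a 16-entry table
    // (k_bit_helper, 0x80 at odd-parity indices) via _mm256_shuffle_epi8.
    const uint8_t folded  = (partial ^ (uint8_t)(q2 >> 13)) & 0xf;
    const uint8_t odd     = (uint8_t)(__builtin_parity(folded) << 7);
    return (uint8_t)(partial | odd);                                // implied 8th bit on top
}

In the vector code, `partial_sign_bits_for_counting` is exactly this fold done in parallel, and `odd_bits` is the table lookup; the previous code instead fetched fully expanded sign bytes from the `keven_signs_q2xs` table.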
- ggml-quants.c +76 -15

ggml-quants.c CHANGED
@@ -8525,17 +8525,36 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 
     const __m128i m4 = _mm_set1_epi8(0xf);
     const __m128i m1 = _mm_set1_epi8(1);
-    const __m128i m511 = _mm_set1_epi16(511);
-    const __m128i m127 = _mm_set1_epi16(127);
+    const __m256i m511 = _mm256_set1_epi16(511);
+    const __m256i mone = _mm256_set1_epi8(1);
 
-    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
+    const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
+    const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
+    const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
 
     uint64_t aux64;
 
     // somewhat hacky, but gives a significant boost in performance
-    __m128i aux_gindex, aux_sindex;
+    __m256i aux_gindex;
     const uint16_t * gindex = (const uint16_t *)&aux_gindex;
-    const uint16_t * sindex = (const uint16_t *)&aux_sindex;
 
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
@@ -8550,26 +8569,68 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
-        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
+            aux_gindex = _mm256_and_si256(q2_data, m511);
+
+            const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
+            const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
+            const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
+
+            const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
+            const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
+
             const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
             const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
-            const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
-            aux_gindex = _mm_and_si128(q2_data, m511);
-            aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
-            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]],
-                                                   iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
-            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]],
-                                                   iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
-            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]));
-            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]));
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+            const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+            const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+            const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
+            const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
+            const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
+            const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
+
+            __m256i signs;
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
+
             const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
             const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
+            const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
 
             const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
             const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
+            const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
+            const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
 
             sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
             sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
         }
 
         accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
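
A note on the sign application in the new inner loop: `_mm256_cmpeq_epi8` against `bit_selector_mask` expands each sign bit into a 0xFF/0x00 byte mask, and the mask is OR-ed with `mone` before `_mm256_sign_epi8` because a control byte of 0 would zero the lane rather than keep it. A per-byte scalar model of that step (function name ours, not from the patch):

#include <stdint.h>

// Per-byte model of: _mm256_sign_epi8(q8_x, _mm256_or_si256(signs, mone)).
// `mask` is one byte of the cmpeq result: 0xFF when this weight's sign bit
// is set, 0x00 otherwise.
static int8_t apply_iq2xs_sign(int8_t q8, uint8_t mask) {
    const int8_t ctrl = (int8_t)(mask | 1); // -1 => negate, +1 => keep, never 0
    return (int8_t)(ctrl * q8);             // sign_epi8 semantics for non-zero control
}

This is why `mone` exists at all: the dot product needs every q8 byte weighted by exactly +1 or -1, never dropped.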