arm64: optimize q6_k_q8_k kernel with i8mm (llama/13519)
This PR improves the q6_k_q8_k GEMM kernel with the arm64 i8mm instruction.
Tested on Neoverse-N2 with a Llama 3 8B Q6_K quantized model.

- 40% ~ 54% S_PP uplift for all batch sizes
- 16% ~ 47% S_TG uplift for batch size 4 and above

Perplexity doesn't change with this PR.
```
// tested on neoverse-n2
$ llama-batched-bench \
-m Meta-Llama-3-8B-Instruct-Q6_K.gguf \
--no-mmap -fa \
-c 8192 -b 4096 -ub 512 -npp 128 -ntg 128 \
-npl 1,2,4,8,16,32 \
-t 64
---------------------------------------------------------------------
| PP | TG | B | S_PP t/s | S_TG t/s |
| | | | original | this pr | original | this pr |
|-------|--------|------|----------|----------|----------|----------|
| 128 | 128 | 1 | 78.52 | 109.18 | 18.63 | 18.88 |
| 128 | 128 | 2 | 84.62 | 123.94 | 34.54 | 36.92 |
| 128 | 128 | 4 | 84.36 | 122.49 | 52.65 | 61.32 |
| 128 | 128 | 8 | 90.52 | 138.87 | 63.46 | 84.41 |
| 128 | 128 | 16 | 90.11 | 138.56 | 71.04 | 101.33 |
| 128 | 128 | 32 | 89.81 | 137.79 | 75.14 | 110.47 |
---------------------------------------------------------------------
```
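
For readers unfamiliar with i8mm: the kernel in the diff below gets its speedup from `vmmlaq_s32` (the SMMLA instruction), which treats each 16-byte operand as a 2x8 int8 matrix and accumulates a 2x2 int32 matrix of dot products. The standalone sketch below is not part of the change; the data and the `main` harness are made up for illustration. It shows the same zip-and-SMMLA pattern the kernel uses to get all four row/column combinations of a 16-element block with two instructions, and assumes a toolchain targeting i8mm (e.g. `-march=armv8.6-a+i8mm`).

```c
// Minimal sketch (not from the PR): one 2x2 tile of int8 dot products via SMMLA.
#include <arm_neon.h>
#include <stdio.h>

int main(void) {
    // Two 16-element int8 rows of "x" and two 16-element rows of "y".
    int8_t x0[16], x1[16], y0[16], y1[16];
    for (int i = 0; i < 16; ++i) { x0[i] = (int8_t) i; x1[i] = (int8_t) -i; y0[i] = 1; y1[i] = 2; }

    const int8x16_t vx0 = vld1q_s8(x0), vx1 = vld1q_s8(x1);
    const int8x16_t vy0 = vld1q_s8(y0), vy1 = vld1q_s8(y1);

    // Interleave the 64-bit halves so each SMMLA operand is a 2x8 int8 matrix:
    // row 0 = the half from x0/y0, row 1 = the half from x1/y1.
    const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0), vreinterpretq_s64_s8(vx1)));
    const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0), vreinterpretq_s64_s8(vx1)));
    const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
    const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));

    // vr += vx * vy^T; after both halves: vr = { x0.y0, x0.y1, x1.y0, x1.y1 }.
    int32x4_t vr = vdupq_n_s32(0);
    vr = vmmlaq_s32(vr, vx_l, vy_l);
    vr = vmmlaq_s32(vr, vx_h, vy_h);

    int32_t r[4];
    vst1q_s32(r, vr);
    printf("x0.y0=%d x0.y1=%d x1.y0=%d x1.y1=%d\n", r[0], r[1], r[2], r[3]);
    return 0;
}
```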
```diff
@@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+#ifdef __ARM_FEATURE_MATMUL_INT8
+    assert((nrc == 2) || (nrc == 1));
+#else
     assert(nrc == 1);
+#endif
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
@@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q6_K * GGML_RESTRICT x0 = x;
+        const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
+        const block_q8_K * GGML_RESTRICT y0 = y;
+        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+        float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+            const uint8_t * GGML_RESTRICT ql0 = x0->ql;
+            const uint8_t * GGML_RESTRICT ql1 = x1->ql;
+            const uint8_t * GGML_RESTRICT qh0 = x0->qh;
+            const uint8_t * GGML_RESTRICT qh1 = x1->qh;
+            const int8_t * GGML_RESTRICT qy0 = y0->qs;
+            const int8_t * GGML_RESTRICT qy1 = y1->qs;
+
+            const uint8x16_t mone = vdupq_n_u8(0x30);
+            const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+            int32x4_t visum = vdupq_n_s32(0);
+
+            // process 8 blocks per iteration, totally 16 blocks
+            for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
+                int8x16_t vx0[8], vx1[8];
+
+                // de-quantize vx0[8]
+                {
+                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
+                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
+
+                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+                    vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+                    vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+                    vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+                    vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+                    vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+                    vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+                    vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+                    vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+                }
+
+                // de-quantize vx1[8]
+                {
+                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
+                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
+
+                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+                    vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+                    vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+                    vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+                    vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+                    vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+                    vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+                    vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+                    vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+                }
+
+                // process 16 elements (one block with same scale) per iteration
+                // - vx = concat(ql, qh) - 32
+                // - r1,r2,r3,r4 = smmla(vx, vy)
+                for (int k = 0; k < 8; ++k) {
+                    const int blk = j * 8 + k;
+
+                    const int8x16_t vy0 = vld1q_s8(qy0);
+                    const int8x16_t vy1 = vld1q_s8(qy1);
+                    qy0 += 16;
+                    qy1 += 16;
+
+                    const int32x4_t block_scale = {
+                        x0->scales[blk],
+                        x0->scales[blk],
+                        x1->scales[blk],
+                        x1->scales[blk],
+                    };
+
+                    // calculate four results at once with outer product
+                    const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+                    const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+                    const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+                    const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+                    int32x4_t vr = vdupq_n_s32(0);
+                    vr = vmmlaq_s32(vr, vx_l, vy_l);
+                    vr = vmmlaq_s32(vr, vx_h, vy_h);
+
+                    // apply block scale, will NOT overflow
+                    // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
+                    visum = vmlaq_s32(visum, vr, block_scale);
+                }
+            }
+
+            // adjust bias, apply superblock scale
+            {
+                int32_t bias[4];
+#ifdef __ARM_FEATURE_SVE
+                const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+                const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
+                const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
+                const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
+                const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
+                const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
+                const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
+                const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
+                const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
+                const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
+                const svint64_t zero = svdup_n_s64(0);
+                bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
+                                                                               svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
+                bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
+                                                                               svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
+                bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
+                                                                               svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
+                bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
+                                                                               svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
+#else
+                // NEON doesn't support int16 dot product, fallback to separated mul and add
+                const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
+                const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
+
+                int8x16_t scales_s8 = vld1q_s8(x0->scales);
+                const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+                scales_s8 = vld1q_s8(x1->scales);
+                const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+
+                int32x4_t prod;
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
+                bias[0] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
+                bias[1] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
+                bias[2] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
+                bias[3] = vaddvq_s32(prod);
+
+#endif
+                const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
+
+                const float32x4_t superblock_scale = {
+                    GGML_FP16_TO_FP32(x0->d) * y0->d,
+                    GGML_FP16_TO_FP32(x0->d) * y1->d,
+                    GGML_FP16_TO_FP32(x1->d) * y0->d,
+                    GGML_FP16_TO_FP32(x1->d) * y1->d,
+                };
+
+                visum = vsubq_s32(visum, vibias);
+                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+            }
+        }
+
+        // vfsum = ABCD -> ACBD
+        // AC -> s, BD -> (s+bs)
+        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+        vst1_f32(s, vget_low_f32 (vfsum));
+        vst1_f32(s + bs, vget_high_f32(vfsum));
+
+        return;
+    }
+#endif
+
 #ifdef __ARM_FEATURE_SVE
     const int vector_length = ggml_cpu_get_sve_cnt()*8;
     float sum = 0;
```
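For reference (not part of the diff): with q6_K storing each 6-bit quant with a +32 offset and q8_K's `bsums` holding per-16-element sums, the quantity the new 2x2 path accumulates per superblock is, for row r and column c in {0, 1}:

```math
s_{rc} = d^{x}_{r}\, d^{y}_{c} \left( \sum_{b=0}^{15} \mathrm{scale}_r[b] \sum_{i \in b} q6_r[i]\, q8_c[i] \;-\; 32 \sum_{b=0}^{15} \mathrm{scale}_r[b]\, \mathrm{bsum}_c[b] \right)
```

The first term is `visum` (block-scaled SMMLA results), the second is `vibias`, and the `d` products form the `superblock_scale` vector.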
```diff
@@ -282,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .from_float = quantize_row_q6_K,
         .vec_dot = ggml_vec_dot_q6_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows = 2,
+#else
         .nrows = 1,
+#endif
     },
     [GGML_TYPE_IQ2_XXS] = {
         .from_float = NULL,
```
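Note that `.nrows = 2` only takes effect when this translation unit is built with i8mm enabled; otherwise the `#else` branch keeps the single-row path. A minimal probe (just a sketch, not from the PR) to check whether a given toolchain/target defines the gating macro:

```c
// probe.c - check whether the compiler defines the ACLE macro that gates the
// i8mm q6_K x q8_K path above. Build e.g.: cc -march=armv8.6-a+i8mm probe.c
#include <stdio.h>

int main(void) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
    printf("i8mm available: the 2-row (nrows = 2) q6_K x q8_K kernel is compiled in\n");
#else
    printf("i8mm not available: falling back to the nrows = 1 path\n");
#endif
    return 0;
}
```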