arm64: optimize q4_k_q8_k kernel with i8mm (llama/13886)
This PR improves the q4_k_q8_k GEMM kernel with the arm64 i8mm instruction.
Tested on Neoverse N2 with a Llama 3 8B Q4_K_M quantized model:
- 34% ~ 50% S_PP uplift for all batch sizes
- 12% ~ 37% S_TG uplift for batch size 4 and above
Perplexity doesn't change with this PR.
```
// tested on neoverse-n2
$ llama-batched-bench \
-m Meta-Llama-3-8B-Instruct-Q4_K_M.gguf \
--no-mmap -fa \
-c 8192 -b 4096 -ub 512 -npp 128 -ntg 128 \
-npl 1,2,4,8,16,32 \
-t 64
---------------------------------------------------------------------
| PP | TG | B | S_PP t/s | S_TG t/s |
| | | | original | this pr | original | this pr |
|-------|--------|------|----------|----------|----------|----------|
| 128 | 128 | 1 | 110.12 | 147.83 | 24.36 | 24.28 |
| 128 | 128 | 2 | 121.16 | 172.42 | 46.36 | 47.93 |
| 128 | 128 | 4 | 120.15 | 169.75 | 74.68 | 84.00 |
| 128 | 128 | 8 | 130.97 | 196.81 | 91.04 | 114.74 |
| 128 | 128 | 16 | 131.01 | 196.88 | 101.43 | 135.79 |
| 128 | 128 | 32 | 130.85 | 196.51 | 106.97 | 147.29 |
---------------------------------------------------------------------
```
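For context on the instruction this PR relies on: the i8mm SMMLA operation (exposed as the `vmmlaq_s32` intrinsic used in the diff below) treats each 128-bit operand as a 2x8 row-major int8 matrix and accumulates their product into a 2x2 int32 accumulator. A minimal scalar sketch of its semantics (my own reference function, not part of the patch):

```c
#include <stdint.h>

// Scalar reference for vmmlaq_s32(acc, a, b) (SMMLA):
// a and b each hold a 2x8 row-major int8 matrix, acc a 2x2 row-major int32 matrix.
// acc[i][j] += dot(row i of a, row j of b)
static void smmla_ref(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
            int32_t sum = 0;
            for (int k = 0; k < 8; ++k) {
                sum += (int32_t) a[i*8 + k] * (int32_t) b[j*8 + k];
            }
            acc[i*2 + j] += sum;
        }
    }
}
```

This is why the kernel interleaves the 64-bit halves of two rows with `vzip1q_s64`/`vzip2q_s64`: each SMMLA operand must carry 8 consecutive int8 values from row 0 in its low half and 8 from row 1 in its high half.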
```diff
@@ -6995,7 +6995,11 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+#ifdef __ARM_FEATURE_MATMUL_INT8
+    assert((nrc == 2) || (nrc == 1));
+#else
     assert(nrc == 1);
+#endif
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
@@ -7012,6 +7016,146 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     uint32_t utmp[4];
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_K * GGML_RESTRICT x0 = x;
+        const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
+        const block_q8_K * GGML_RESTRICT y0 = y;
+        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+        float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+            const uint8_t * GGML_RESTRICT qx0 = x0->qs;
+            const uint8_t * GGML_RESTRICT qx1 = x1->qs;
+            const int8_t  * GGML_RESTRICT qy0 = y0->qs;
+            const int8_t  * GGML_RESTRICT qy1 = y1->qs;
+
+            // decode scales and mins
+            int8_t x0_scales[8], x1_scales[8];
+            int16x8_t x0_mins, x1_mins;
+            {
+                uint32_t scales_mins[3];
+                memcpy(scales_mins, x0->scales, 12);
+                const uint32_t mins_0_3 = scales_mins[1] & kmask1;
+                const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
+                const uint32x2_t mins = {mins_0_3, mins_4_7};
+                x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
+                uint32_t scales[2];
+                scales[0] = scales_mins[0] & kmask1; // scales 0~3
+                scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
+                memcpy(x0_scales, scales, 8);
+            }
+            {
+                uint32_t scales_mins[3];
+                memcpy(scales_mins, x1->scales, 12);
+                const uint32_t mins_0_3 = scales_mins[1] & kmask1;
+                const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
+                const uint32x2_t mins = {mins_0_3, mins_4_7};
+                x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
+                uint32_t scales[2];
+                scales[0] = scales_mins[0] & kmask1; // scales 0~3
+                scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
+                memcpy(x1_scales, scales, 8);
+            }
+
+            int32x4_t visum = {0};
+
+            // process 64 data points per iteration, totally 256 data points
+            for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) {
+                const int8x16x4_t vy0 = vld1q_s8_x4(qy0);
+                const int8x16x4_t vy1 = vld1q_s8_x4(qy1);
+
+                int8x16_t vx0[4], vx1[4];
+                {
+                    const uint8x16x2_t vv = vld1q_u8_x2(qx0);
+                    vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
+                    vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
+                    vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
+                    vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
+                }
+                {
+                    const uint8x16x2_t vv = vld1q_u8_x2(qx1);
+                    vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
+                    vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
+                    vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
+                    vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
+                }
+
+                // process 32 data points (share same block scale) per iteration
+                for (int k = 0; k < 2; ++k) {
+                    const int blk = j * 2 + k;
+                    const int32x4_t block_scale = {
+                        x0_scales[blk],
+                        x0_scales[blk],
+                        x1_scales[blk],
+                        x1_scales[blk],
+                    };
+
+                    int32x4_t vr = {0};
+                    for (int l = 0; l < 2; ++l) {
+                        const int idx = k * 2 + l;
+                        const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]);
+                        const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]);
+                        const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]);
+                        const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]);
+                        const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64));
+                        const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64));
+                        const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64));
+                        const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64));
+                        vr = vmmlaq_s32(vr, vx_l, vy_l);
+                        vr = vmmlaq_s32(vr, vx_h, vy_h);
+                    }
+                    // apply block scale, will NOT overflow
+                    // block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 28 bits
+                    visum = vmlaq_s32(visum, vr, block_scale);
+                }
+            }
+
+            // adjust bias, apply superblock scale
+            {
+                int32_t bias[4];
+                // no obvious uplift from sve sdot-16, just use neon mul add
+                const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8));
+                const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8));
+                bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)),
+                                               vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins))));
+                bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)),
+                                               vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins))));
+                bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)),
+                                               vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins))));
+                bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)),
+                                               vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins))));
+                const float32x4_t dmins = {
+                    GGML_FP16_TO_FP32(x0->dmin) * y0->d,
+                    GGML_FP16_TO_FP32(x0->dmin) * y1->d,
+                    GGML_FP16_TO_FP32(x1->dmin) * y0->d,
+                    GGML_FP16_TO_FP32(x1->dmin) * y1->d,
+                };
+                vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins);
+
+                const float32x4_t superblock_scale = {
+                    GGML_FP16_TO_FP32(x0->d) * y0->d,
+                    GGML_FP16_TO_FP32(x0->d) * y1->d,
+                    GGML_FP16_TO_FP32(x1->d) * y0->d,
+                    GGML_FP16_TO_FP32(x1->d) * y1->d,
+                };
+                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+            }
+        }
+
+        // vfsum = ABCD -> ACBD
+        // AC -> s, BD -> (s+bs)
+        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+        vst1_f32(s,      vget_low_f32 (vfsum));
+        vst1_f32(s + bs, vget_high_f32(vfsum));
+
+        return;
+    }
+#endif
+
 #ifdef __ARM_FEATURE_SVE
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
```
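As a sanity check on the overflow comment inside the block-scale loop (my own arithmetic, not part of the patch): each lane of `vr` sums 32 products of an unsigned 4-bit nibble (at most 15) and an int8 value (magnitude at most 128), the 6-bit block scale is at most 63, and `visum` is reset for every 256-value superblock, so it accumulates at most 8 such scaled sub-blocks:

```
|vr|    <= 32 * 15 * 128   =     61 440 < 2^16
|visum| <= 8 * 63 * 61 440 = 30 965 760 < 2^25
```

both comfortably inside the int32 accumulators.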
```diff
@@ -270,7 +270,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .from_float    = quantize_row_q4_K,
         .vec_dot       = ggml_vec_dot_q4_K_q8_K,
         .vec_dot_type  = GGML_TYPE_Q8_K,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows         = 2,
+#else
         .nrows         = 1,
+#endif
     },
     [GGML_TYPE_Q5_K] = {
         .from_float    = quantize_row_q5_K,
```
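With `.nrows = 2` the CPU backend may hand this kernel two rows of each operand per call when i8mm is compiled in, so a single call to the `nrc == 2` path above fills a 2x2 tile of dot products. A sketch of the resulting output layout, inferred from the kernel's pointer math and final stores (not a verbatim caller from ggml):

```c
// nrc == 2 (inferred from the kernel above; names follow its parameters):
//   x rows:  x0 = vx,  x1 = vx + bx   (two q4_K rows)
//   y rows:  y0 = vy,  y1 = vy + by   (two q8_K rows)
// outputs: s[0]      = dot(x0, y0)
//          s[1]      = dot(x1, y0)
//          s[bs]     = dot(x0, y1)
//          s[bs + 1] = dot(x1, y1)
```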