ggml : add ggml_vdotq_s32 alias (llama/4715)

ggml-quants.c  (+61 -57)
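
The change is confined to the ARM NEON paths. The NEON fallback for the dot-product instruction, previously declared under the intrinsic name vdotq_s32, is renamed to ggml_vdotq_s32, and when the target provides __ARM_FEATURE_DOTPROD the new name becomes a thin macro over the native vdotq_s32 intrinsic. Every quantized dot-product kernel (q4_0, q4_1, q5_0, q5_1, q8_0 and the K-quant variants) is updated to call the prefixed name; the arithmetic itself is unchanged.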
@@ -410,13 +410,17 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
 
 #if !defined(__ARM_FEATURE_DOTPROD)
 
-inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
     const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
     const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
 
     return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
 }
 
+#else
+
+#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
 #endif
 
 #endif

@@ -2481,8 +2485,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
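
The nested calls that build p_0 and p_1 above are the idiom used throughout the rest of the patch: chaining two ggml_vdotq_s32 calls through one accumulator folds a 32-byte signed dot product into a single int32x4_t before any float scaling. A minimal sketch of the same idiom reduced to a scalar total (the helper name is ours, not part of the patch, and it assumes the ggml_vdotq_s32 definition from the first hunk):

#include <arm_neon.h>

// Illustration only: fold two 16-byte halves into one int32x4_t, then reduce.
// Assumes ggml_vdotq_s32 as introduced in the first hunk of this commit.
static inline int32_t dot32_total(int8x16_t a_lo, int8x16_t a_hi,
                                  int8x16_t b_lo, int8x16_t b_hi) {
    const int32x4_t p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), a_lo, b_lo), a_hi, b_hi);
    return vaddvq_s32(p);  // horizontal add of the four partial lanes
}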
@@ -2769,8 +2773,8 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);

@@ -2936,11 +2940,11 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);

@@ -3228,11 +3232,11 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
+                ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
+                ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;

@@ -3483,12 +3487,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
         const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
-                vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
 
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
-                vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);

@@ -3598,8 +3602,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         // We use this macro instead of a function call because for some reason
         // the code runs 2-3% slower, even if the function is declared inline
 #define MULTIPLY_ACCUM_WITH_SCALE(index)\
-        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
-        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
 
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
         q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\

@@ -3973,10 +3977,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
         q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
 
-        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
-        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
-        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
-        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
+        isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
+        isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
+        isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
 
         sum += d * (isum1 + isum2);
     }

@@ -4256,10 +4260,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
         q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
 
         scale += 4;
 

@@ -4273,10 +4277,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
         q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
 
         scale += 4;
 

@@ -4757,10 +4761,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
         q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
 
         sum += d * isum;
 

@@ -5109,14 +5113,14 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
 
-        const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+        const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
         sumi1 += vaddvq_s32(p1) * scales[2*j+0];
 
         q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
-        const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+        const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
 
         sumi2 += vaddvq_s32(p2) * scales[2*j+1];
     }

@@ -5449,13 +5453,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
 
-        const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+        const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
         const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
 
         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
-        const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
+        const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
         const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
         sumf += d * (sumi1 + sumi2);

@@ -5722,8 +5726,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
         q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
 
-        sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
-        sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
+        sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+        sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
     }
 
     sumf += d * sumi - dmin * sumi_mins;

@@ -6112,10 +6116,10 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
         q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
 
-        int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
-        int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
-        int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
-        int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+        int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
 
         sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
     }

@@ -6399,10 +6403,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
         q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
 
         scale += 4;
 

@@ -6426,10 +6430,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
         q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
         scale += 4;
     }
     //sum += isum * d_all * y[i].d;

@@ -6816,10 +6820,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
         q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
 
         sum += isum * d_all * y[i].d;
 
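
One detail worth keeping in mind: the fallback distributes its partial sums across the four int32 lanes differently from the dot-product instruction, so the two paths agree only after a horizontal reduction. Every call site in the hunks above goes through vaddvq_s32 or a final vaddvq_f32, so only the total matters. The following self-contained check is ours, not part of the commit; it compares the fallback against a plain scalar dot product. Building without dot-product support (e.g. plain -march=armv8-a with GCC/Clang) exercises the fallback, while -march=armv8.2-a+dotprod should define __ARM_FEATURE_DOTPROD and take the intrinsic path.

#include <arm_neon.h>
#include <stdint.h>
#include <stdio.h>

// Same definitions as the first hunk of the patch.
#if !defined(__ARM_FEATURE_DOTPROD)
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));

    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
}
#else
#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
#endif

int main(void) {
    int8_t a[16], b[16];
    int32_t ref = 0;
    for (int k = 0; k < 16; ++k) {
        a[k] = (int8_t)(k - 8);      // arbitrary small test values
        b[k] = (int8_t)(3*k - 20);
        ref += (int32_t)a[k] * (int32_t)b[k];
    }
    // Reduce the four lanes so both the intrinsic path and the fallback agree.
    const int32_t got = vaddvq_s32(ggml_vdotq_s32(vdupq_n_s32(0), vld1q_s8(a), vld1q_s8(b)));
    printf("scalar=%d neon=%d\n", (int) ref, (int) got);
    return got == ref ? 0 : 1;
}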