Spaces:
Running
Running
ggml : activate s390x simd for Q3_K (llama/13301)
Browse filesSigned-off-by: Aaron Teo <[email protected]>
ggml/src/ggml-cpu/ggml-cpu-quants.c
CHANGED
|
@@ -6590,7 +6590,118 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
|
| 6590 |
}
|
| 6591 |
|
| 6592 |
*s = hsum_float_8(acc);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6593 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6594 |
#else
|
| 6595 |
// scalar version
|
| 6596 |
// This function is written like this so the compiler can manage to vectorize most of it
|
|
|
|
| 6590 |
}
|
| 6591 |
|
| 6592 |
*s = hsum_float_8(acc);
|
| 6593 |
+
#elif defined(__VXE__) || defined(__VXE2__)
|
| 6594 |
+
uint32_t aux[3];
|
| 6595 |
+
uint32_t utmp[4];
|
| 6596 |
+
|
| 6597 |
+
const int32x4_t v_z = vec_splat_s32(0);
|
| 6598 |
+
const uint8x16_t v_3m = vec_splat_u8(0x03);
|
| 6599 |
+
|
| 6600 |
+
const uint8x16_t v_0c = vec_splat_u8(1);
|
| 6601 |
+
const uint8x16_t v_1c = vec_sl(v_0c, 1);
|
| 6602 |
+
const uint8x16_t v_2c = vec_sl(v_0c, 2);
|
| 6603 |
+
const uint8x16_t v_3c = vec_sl(v_0c, 3);
|
| 6604 |
+
|
| 6605 |
+
uint8x16_t q3h[4];
|
| 6606 |
+
uint8x16_t q3b[2];
|
| 6607 |
+
int8x16_t q3bytes[4];
|
| 6608 |
+
int8x16_t q8bytes[4];
|
| 6609 |
+
uint8x16_t qhbits[2];
|
| 6610 |
+
|
| 6611 |
+
float sum = 0;
|
| 6612 |
+
|
| 6613 |
+
for (int i = 0; i < nb; ++i) {
|
| 6614 |
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
| 6615 |
|
| 6616 |
+
const uint8_t * restrict x0l = x[i].qs;
|
| 6617 |
+
const uint8_t * restrict x0h = x[i].hmask;
|
| 6618 |
+
const int8_t * restrict y0 = y[i].qs;
|
| 6619 |
+
|
| 6620 |
+
qhbits[0] = vec_xl(0 , x0h);
|
| 6621 |
+
qhbits[1] = vec_xl(16, x0h);
|
| 6622 |
+
|
| 6623 |
+
int32_t isum = 0;
|
| 6624 |
+
|
| 6625 |
+
memcpy(aux, x[i].scales, 12);
|
| 6626 |
+
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
|
| 6627 |
+
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
|
| 6628 |
+
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
|
| 6629 |
+
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
|
| 6630 |
+
|
| 6631 |
+
int8_t * scale = (int8_t *)utmp;
|
| 6632 |
+
for (int j = 0; j < 16; ++j) scale[j] -= 32;
|
| 6633 |
+
|
| 6634 |
+
for (int j = 0; j < QK_K/128; ++j) {
|
| 6635 |
+
int32x4_t isum0, isum1, isum2, isum3;
|
| 6636 |
+
|
| 6637 |
+
q3b[0] = vec_xl(0 , x0l);
|
| 6638 |
+
q3b[1] = vec_xl(16, x0l);
|
| 6639 |
+
x0l += 32;
|
| 6640 |
+
|
| 6641 |
+
q8bytes[0] = vec_xl(0 , y0);
|
| 6642 |
+
q8bytes[1] = vec_xl(16 , y0);
|
| 6643 |
+
q8bytes[2] = vec_xl(32 , y0);
|
| 6644 |
+
q8bytes[3] = vec_xl(48 , y0);
|
| 6645 |
+
q8bytes[4] = vec_xl(64 , y0);
|
| 6646 |
+
q8bytes[5] = vec_xl(80 , y0);
|
| 6647 |
+
q8bytes[6] = vec_xl(96 , y0);
|
| 6648 |
+
q8bytes[7] = vec_xl(112, y0);
|
| 6649 |
+
y0 += 128;
|
| 6650 |
+
|
| 6651 |
+
q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2);
|
| 6652 |
+
q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2);
|
| 6653 |
+
q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1);
|
| 6654 |
+
q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1);
|
| 6655 |
+
|
| 6656 |
+
q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]);
|
| 6657 |
+
q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]);
|
| 6658 |
+
q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]);
|
| 6659 |
+
q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]);
|
| 6660 |
+
|
| 6661 |
+
isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]);
|
| 6662 |
+
isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]);
|
| 6663 |
+
isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]);
|
| 6664 |
+
isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]);
|
| 6665 |
+
|
| 6666 |
+
isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
|
| 6667 |
+
isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
|
| 6668 |
+
isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
|
| 6669 |
+
isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
|
| 6670 |
+
|
| 6671 |
+
scale += 4;
|
| 6672 |
+
|
| 6673 |
+
q3h[0] = vec_andc(v_2c, qhbits[0]);
|
| 6674 |
+
q3h[1] = vec_andc(v_2c, qhbits[1]);
|
| 6675 |
+
q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1);
|
| 6676 |
+
q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1);
|
| 6677 |
+
|
| 6678 |
+
q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]);
|
| 6679 |
+
q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]);
|
| 6680 |
+
q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]);
|
| 6681 |
+
q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]);
|
| 6682 |
+
|
| 6683 |
+
isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]);
|
| 6684 |
+
isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]);
|
| 6685 |
+
isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
|
| 6686 |
+
isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
|
| 6687 |
+
|
| 6688 |
+
isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
|
| 6689 |
+
isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
|
| 6690 |
+
isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
|
| 6691 |
+
isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
|
| 6692 |
+
|
| 6693 |
+
scale += 4;
|
| 6694 |
+
|
| 6695 |
+
if (j == 0) {
|
| 6696 |
+
qhbits[0] = vec_sr(qhbits[0], 4);
|
| 6697 |
+
qhbits[1] = vec_sr(qhbits[1], 4);
|
| 6698 |
+
}
|
| 6699 |
+
}
|
| 6700 |
+
|
| 6701 |
+
sum += d * isum;
|
| 6702 |
+
}
|
| 6703 |
+
|
| 6704 |
+
*s = sum;
|
| 6705 |
#else
|
| 6706 |
// scalar version
|
| 6707 |
// This function is written like this so the compiler can manage to vectorize most of it
|