Spaces:
Running
Running
ggml : fix UB in IQ2_S and IQ3_S (llama/6012)
Browse files- ggml-quants.c +6 -6
ggml-quants.c
CHANGED
|
@@ -9025,7 +9025,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
|
|
| 9025 |
vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
|
| 9026 |
qs += 8;
|
| 9027 |
|
| 9028 |
-
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
|
| 9029 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 9030 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 9031 |
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
|
@@ -9034,7 +9034,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
|
|
| 9034 |
q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
|
| 9035 |
q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
|
| 9036 |
|
| 9037 |
-
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
|
| 9038 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 9039 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 9040 |
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
|
@@ -9105,12 +9105,12 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
|
|
| 9105 |
iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
|
| 9106 |
qs += 8;
|
| 9107 |
|
| 9108 |
-
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
|
| 9109 |
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
|
| 9110 |
const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
|
| 9111 |
const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
|
| 9112 |
|
| 9113 |
-
aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
|
| 9114 |
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
|
| 9115 |
const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
|
| 9116 |
const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
|
|
@@ -9386,7 +9386,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
|
|
| 9386 |
iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
|
| 9387 |
|
| 9388 |
|
| 9389 |
-
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
|
| 9390 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 9391 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 9392 |
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
|
|
@@ -9395,7 +9395,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
|
|
| 9395 |
q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
|
| 9396 |
q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
|
| 9397 |
|
| 9398 |
-
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
|
| 9399 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 9400 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 9401 |
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
|
|
|
|
| 9025 |
vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
|
| 9026 |
qs += 8;
|
| 9027 |
|
| 9028 |
+
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
|
| 9029 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 9030 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 9031 |
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
|
|
|
| 9034 |
q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
|
| 9035 |
q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
|
| 9036 |
|
| 9037 |
+
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
|
| 9038 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 9039 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 9040 |
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
|
|
|
| 9105 |
iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
|
| 9106 |
qs += 8;
|
| 9107 |
|
| 9108 |
+
__m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
|
| 9109 |
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
|
| 9110 |
const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
|
| 9111 |
const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
|
| 9112 |
|
| 9113 |
+
aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
|
| 9114 |
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
|
| 9115 |
const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
|
| 9116 |
const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
|
|
|
|
| 9386 |
iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
|
| 9387 |
|
| 9388 |
|
| 9389 |
+
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
|
| 9390 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 9391 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 9392 |
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
|
|
|
|
| 9395 |
q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
|
| 9396 |
q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
|
| 9397 |
|
| 9398 |
+
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
|
| 9399 |
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
| 9400 |
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
| 9401 |
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
|