ggerganov commited on
Commit
0c552df
·
unverified ·
1 Parent(s): 16dc72c

ggml : fix UB in IQ2_S and IQ3_S (llama/6012)

Browse files
Files changed (1) hide show
  1. ggml-quants.c +6 -6
ggml-quants.c CHANGED
@@ -9025,7 +9025,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
9025
  vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
9026
  qs += 8;
9027
 
9028
- vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
9029
  vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9030
  vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9031
  vs.val[0] = vceqq_u8(vs.val[0], mask2);
@@ -9034,7 +9034,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
9034
  q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
9035
  q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
9036
 
9037
- vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
9038
  vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9039
  vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9040
  vs.val[0] = vceqq_u8(vs.val[0], mask2);
@@ -9105,12 +9105,12 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
9105
  iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
9106
  qs += 8;
9107
 
9108
- __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
9109
  aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
9110
  const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
9111
  const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
9112
 
9113
- aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
9114
  aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
9115
  const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
9116
  const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
@@ -9386,7 +9386,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
9386
  iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
9387
 
9388
 
9389
- vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
9390
  vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9391
  vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9392
  vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
@@ -9395,7 +9395,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
9395
  q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
9396
  q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
9397
 
9398
- vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
9399
  vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9400
  vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9401
  vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
 
9025
  vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
9026
  qs += 8;
9027
 
9028
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
9029
  vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9030
  vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9031
  vs.val[0] = vceqq_u8(vs.val[0], mask2);
 
9034
  q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
9035
  q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
9036
 
9037
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
9038
  vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9039
  vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9040
  vs.val[0] = vceqq_u8(vs.val[0], mask2);
 
9105
  iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
9106
  qs += 8;
9107
 
9108
+ __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
9109
  aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
9110
  const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
9111
  const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
9112
 
9113
+ aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
9114
  aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
9115
  const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
9116
  const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
 
9386
  iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
9387
 
9388
 
9389
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
9390
  vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9391
  vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9392
  vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
 
9395
  q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
9396
  q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
9397
 
9398
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
9399
  vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
9400
  vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
9401
  vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);